Implement bags rewriter (#5132)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Mon, 28 Sep 2020 13:53:07 +0000 (08:53 -0500)
committerGitHub <noreply@github.com>
Mon, 28 Sep 2020 13:53:07 +0000 (08:53 -0500)
This PR implements rewrite rules for bags. This PR focuses on rewrite rules for non constant nodes.
Rewriting nodes with constant children is delegated to bags::NormalForm class (future PR).

18 files changed:
src/CMakeLists.txt
src/theory/bags/bags_rewriter.cpp [new file with mode: 0644]
src/theory/bags/bags_rewriter.h [new file with mode: 0644]
src/theory/bags/bags_statistics.cpp [new file with mode: 0644]
src/theory/bags/bags_statistics.h [new file with mode: 0644]
src/theory/bags/kinds
src/theory/bags/normal_form.cpp
src/theory/bags/normal_form.h
src/theory/bags/rewrites.cpp [new file with mode: 0644]
src/theory/bags/rewrites.h [new file with mode: 0644]
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/bags/theory_bags_rewriter.cpp [deleted file]
src/theory/bags/theory_bags_rewriter.h [deleted file]
src/theory/bags/theory_bags_type_enumerator.h
src/theory/bags/theory_bags_type_rules.h
test/unit/theory/CMakeLists.txt
test/unit/theory/theory_bags_rewriter_black.h [new file with mode: 0644]

index 717378b27941277ab2e3fecc7f3044bbb49e3ebb..74dcc39b335c9b231b7d4da91f81596578d6d7ef 100644 (file)
@@ -404,18 +404,22 @@ libcvc4_add_sources(
   theory/assertion.h
   theory/atom_requests.cpp
   theory/atom_requests.h
+  theory/bags/bags_rewriter.cpp
+  theory/bags/bags_rewriter.h
+  theory/bags/bags_statistics.cpp
+  theory/bags/bags_statistics.h
   theory/bags/inference_manager.cpp
   theory/bags/inference_manager.h
   theory/bags/normal_form.cpp
   theory/bags/normal_form.h
+  theory/bags/rewrites.cpp
+  theory/bags/rewrites.h
   theory/bags/solver_state.cpp
   theory/bags/solver_state.h
   theory/bags/term_registry.cpp
   theory/bags/term_registry.h
   theory/bags/theory_bags.cpp
   theory/bags/theory_bags.h
-  theory/bags/theory_bags_rewriter.cpp
-  theory/bags/theory_bags_rewriter.h
   theory/bags/theory_bags_type_enumerator.cpp
   theory/bags/theory_bags_type_enumerator.h
   theory/bags/theory_bags_type_rules.h
diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp
new file mode 100644 (file)
index 0000000..a6506f1
--- /dev/null
@@ -0,0 +1,434 @@
+/*********************                                                        */
+/*! \file bags_rewriter.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Bags theory rewriter.
+ **/
+
+#include "theory/bags/bags_rewriter.h"
+
+#include "normal_form.h"
+
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+BagsRewriteResponse::BagsRewriteResponse()
+    : d_node(Node::null()), d_rewrite(Rewrite::NONE)
+{
+}
+
+BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
+    : d_node(n), d_rewrite(rewrite)
+{
+}
+
+BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
+    : d_node(r.d_node), d_rewrite(r.d_rewrite)
+{
+}
+
+BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
+    : d_statistics(statistics)
+{
+  d_nm = NodeManager::currentNM();
+}
+
+RewriteResponse BagsRewriter::postRewrite(TNode n)
+{
+  BagsRewriteResponse response;
+  if (n.isConst())
+  {
+    // no need to rewrite n if it is already in a normal form
+    response = BagsRewriteResponse(n, Rewrite::NONE);
+  }
+  else if (NormalForm::AreChildrenConstants(n))
+  {
+    Node value = NormalForm::evaluate(n);
+    response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
+  }
+  else
+  {
+    Kind k = n.getKind();
+    switch (k)
+    {
+      case MK_BAG: response = rewriteMakeBag(n); break;
+      case BAG_COUNT: response = rewriteBagCount(n); break;
+      case UNION_MAX: response = rewriteUnionMax(n); break;
+      case UNION_DISJOINT: response = rewriteUnionDisjoint(n); break;
+      case INTERSECTION_MIN: response = rewriteIntersectionMin(n); break;
+      case DIFFERENCE_SUBTRACT: response = rewriteDifferenceSubtract(n); break;
+      case DIFFERENCE_REMOVE: response = rewriteDifferenceRemove(n); break;
+      case BAG_CHOOSE: response = rewriteChoose(n); break;
+      case BAG_CARD: response = rewriteCard(n); break;
+      case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break;
+      default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
+    }
+  }
+
+  Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
+                        << " by " << response.d_rewrite << "." << std::endl;
+
+  if (d_statistics != nullptr)
+  {
+    (*d_statistics) << response.d_rewrite;
+  }
+  if (response.d_node != n)
+  {
+    return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
+  }
+  return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
+}
+
+RewriteResponse BagsRewriter::preRewrite(TNode n)
+{
+  BagsRewriteResponse response;
+  Kind k = n.getKind();
+  switch (k)
+  {
+    case EQUAL: response = rewriteEqual(n); break;
+    case BAG_IS_INCLUDED: response = rewriteIsIncluded(n); break;
+    default: response = BagsRewriteResponse(n, Rewrite::NONE);
+  }
+
+  Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
+                        << " by " << response.d_rewrite << "." << std::endl;
+
+  if (d_statistics != nullptr)
+  {
+    (*d_statistics) << response.d_rewrite;
+  }
+  if (response.d_node != n)
+  {
+    return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
+  }
+  return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteEqual(const TNode& n) const
+{
+  Assert(n.getKind() == EQUAL);
+  if (n[0] == n[1])
+  {
+    // (= A A) = true where A is a bag
+    return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteIsIncluded(const TNode& n) const
+{
+  Assert(n.getKind() == BAG_IS_INCLUDED);
+
+  // (bag.is_included A B) = ((difference_subtract A B) == emptybag)
+  Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType()));
+  Node subtract = d_nm->mkNode(DIFFERENCE_SUBTRACT, n[0], n[1]);
+  Node equal = subtract.eqNode(emptybag);
+  return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
+{
+  Assert(n.getKind() == MK_BAG);
+  // return emptybag for negative or zero multiplicity
+  if (n[1].isConst() && n[1].getConst<Rational>().sgn() != 1)
+  {
+    // (mkBag x c) = emptybag where c <= 0
+    Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
+    return BagsRewriteResponse(emptybag, Rewrite::MK_BAG_COUNT_NEGATIVE);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
+{
+  Assert(n.getKind() == BAG_COUNT);
+  if (n[1].isConst() && n[1].getKind() == EMPTYBAG)
+  {
+    // (bag.count x emptybag) = 0
+    return BagsRewriteResponse(d_nm->mkConst(Rational(0)),
+                               Rewrite::COUNT_EMPTY);
+  }
+  if (n[1].getKind() == MK_BAG && n[0] == n[1][0])
+  {
+    // (bag.count x (mkBag x c) = c where c > 0 is a constant
+    return BagsRewriteResponse(n[1][1], Rewrite::COUNT_MK_BAG);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
+{
+  Assert(n.getKind() == UNION_MAX);
+  if (n[1].getKind() == EMPTYBAG || n[0] == n[1])
+  {
+    // (union_max A A) = A
+    // (union_max A emptybag) = A
+    return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY);
+  }
+  if (n[0].getKind() == EMPTYBAG)
+  {
+    // (union_max emptybag A) = A
+    return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY);
+  }
+
+  if ((n[1].getKind() == UNION_MAX || n[1].getKind() == UNION_DISJOINT)
+      && (n[0] == n[1][0] || n[0] == n[1][1]))
+  {
+    // (union_max A (union_max A B)) = (union_max A B)
+    // (union_max A (union_max B A)) = (union_max B A)
+    // (union_max A (union_disjoint A B)) = (union_disjoint A B)
+    // (union_max A (union_disjoint B A)) = (union_disjoint B A)
+    return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT);
+  }
+
+  if ((n[0].getKind() == UNION_MAX || n[0].getKind() == UNION_DISJOINT)
+      && (n[0][0] == n[1] || n[0][1] == n[1]))
+  {
+    // (union_max (union_max A B) A)) = (union_max A B)
+    // (union_max (union_max B A) A)) = (union_max B A)
+    // (union_max (union_disjoint A B) A)) = (union_disjoint A B)
+    // (union_max (union_disjoint B A) A)) = (union_disjoint B A)
+    return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
+{
+  Assert(n.getKind() == UNION_DISJOINT);
+  if (n[1].getKind() == EMPTYBAG)
+  {
+    // (union_disjoint A emptybag) = A
+    return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
+  }
+  if (n[0].getKind() == EMPTYBAG)
+  {
+    // (union_disjoint emptybag A) = A
+    return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT);
+  }
+  if ((n[0].getKind() == UNION_MAX && n[1].getKind() == INTERSECTION_MIN)
+      || (n[1].getKind() == UNION_MAX && n[0].getKind() == INTERSECTION_MIN))
+
+  {
+    // (union_disjoint (union_max A B) (intersection_min A B)) =
+    //         (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+    // check if the operands of union_max and intersection_min are the same
+    std::set<Node> left(n[0].begin(), n[0].end());
+    std::set<Node> right(n[0].begin(), n[0].end());
+    if (left == right)
+    {
+      Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]);
+      return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN);
+    }
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
+{
+  Assert(n.getKind() == INTERSECTION_MIN);
+  if (n[0].getKind() == EMPTYBAG)
+  {
+    // (intersection_min emptybag A) = emptybag
+    return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT);
+  }
+  if (n[1].getKind() == EMPTYBAG)
+  {
+    // (intersection_min A emptybag) = emptybag
+    return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT);
+  }
+  if (n[0] == n[1])
+  {
+    // (intersection_min A A) = A
+    return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME);
+  }
+  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
+  {
+    if (n[0] == n[1][0] || n[0] == n[1][1])
+    {
+      // (intersection_min A (union_disjoint A B)) = A
+      // (intersection_min A (union_disjoint B A)) = A
+      // (intersection_min A (union_max A B)) = A
+      // (intersection_min A (union_max B A)) = A
+      return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT);
+    }
+  }
+
+  if (n[0].getKind() == UNION_DISJOINT || n[0].getKind() == UNION_MAX)
+  {
+    if (n[1] == n[0][0] || n[1] == n[0][1])
+    {
+      // (intersection_min (union_disjoint A B) A) = A
+      // (intersection_min (union_disjoint B A) A) = A
+      // (intersection_min (union_max A B) A) = A
+      // (intersection_min (union_max B A) A) = A
+      return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT);
+    }
+  }
+
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
+    const TNode& n) const
+{
+  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
+  if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
+  {
+    // (difference_subtract A emptybag) = A
+    // (difference_subtract emptybag A) = emptybag
+    return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT);
+  }
+  if (n[0] == n[1])
+  {
+    // (difference_subtract A A) = emptybag
+    Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+    return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME);
+  }
+
+  if (n[0].getKind() == UNION_DISJOINT)
+  {
+    if (n[1] == n[0][0])
+    {
+      // (difference_subtract (union_disjoint A B) A) = B
+      return BagsRewriteResponse(n[0][1],
+                                 Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
+    }
+    if (n[1] == n[0][1])
+    {
+      // (difference_subtract (union_disjoint B A) A) = B
+      return BagsRewriteResponse(n[0][0],
+                                 Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
+    }
+  }
+
+  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
+  {
+    if (n[0] == n[1][0] || n[0] == n[1][1])
+    {
+      // (difference_subtract A (union_disjoint A B)) = emptybag
+      // (difference_subtract A (union_disjoint B A)) = emptybag
+      // (difference_subtract A (union_max A B)) = emptybag
+      // (difference_subtract A (union_max B A)) = emptybag
+      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+      return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION);
+    }
+  }
+
+  if (n[0].getKind() == INTERSECTION_MIN)
+  {
+    if (n[1] == n[0][0] || n[1] == n[0][1])
+    {
+      // (difference_subtract (intersection_min A B) A) = emptybag
+      // (difference_subtract (intersection_min B A) A) = emptybag
+      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+      return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN);
+    }
+  }
+
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
+{
+  Assert(n.getKind() == DIFFERENCE_REMOVE);
+
+  if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
+  {
+    // (difference_remove A emptybag) = A
+    // (difference_remove emptybag B) = emptybag
+    return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT);
+  }
+
+  if (n[0] == n[1])
+  {
+    // (difference_remove A A) = emptybag
+    Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+    return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME);
+  }
+
+  if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
+  {
+    if (n[0] == n[1][0] || n[0] == n[1][1])
+    {
+      // (difference_remove A (union_disjoint A B)) = emptybag
+      // (difference_remove A (union_disjoint B A)) = emptybag
+      // (difference_remove A (union_max A B)) = emptybag
+      // (difference_remove A (union_max B A)) = emptybag
+      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+      return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION);
+    }
+  }
+
+  if (n[0].getKind() == INTERSECTION_MIN)
+  {
+    if (n[1] == n[0][0] || n[1] == n[0][1])
+    {
+      // (difference_remove (intersection_min A B) A) = emptybag
+      // (difference_remove (intersection_min B A) A) = emptybag
+      Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+      return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN);
+    }
+  }
+
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
+{
+  Assert(n.getKind() == BAG_CHOOSE);
+  if (n[0].getKind() == MK_BAG && n[0][1].isConst())
+  {
+    // (bag.choose (mkBag x c)) = x where c is a constant > 0
+    return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
+{
+  Assert(n.getKind() == BAG_CARD);
+  if (n[0].getKind() == MK_BAG && n[0][1].isConst())
+  {
+    // (bag.card (mkBag x c)) = c where c is a constant > 0
+    return BagsRewriteResponse(n[0][1], Rewrite::CARD_MK_BAG);
+  }
+
+  if (n[0].getKind() == UNION_DISJOINT)
+  {
+    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+    Node A = d_nm->mkNode(BAG_CARD, n[0][0]);
+    Node B = d_nm->mkNode(BAG_CARD, n[0][1]);
+    Node plus = d_nm->mkNode(PLUS, A, B);
+    return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT);
+  }
+
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const
+{
+  Assert(n.getKind() == BAG_IS_SINGLETON);
+  if (n[0].getKind() == MK_BAG)
+  {
+    // (bag.is_singleton (mkBag x c)) = (c == 1)
+    Node one = d_nm->mkConst(Rational(1));
+    Node equal = n[0][1].eqNode(one);
+    return BagsRewriteResponse(equal, Rewrite::IS_SINGLETON_MK_BAG);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h
new file mode 100644 (file)
index 0000000..f0998a2
--- /dev/null
@@ -0,0 +1,195 @@
+/*********************                                                        */
+/*! \file bags_rewriter.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Bags theory rewriter.
+ **/
+
+#include "cvc4_private.h"
+#include "theory/bags/rewrites.h"
+
+#ifndef CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
+#define CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
+
+#include "theory/rewriter.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/** a class represents the result of rewriting bag nodes */
+struct BagsRewriteResponse
+{
+  BagsRewriteResponse();
+  BagsRewriteResponse(Node n, Rewrite rewrite);
+  BagsRewriteResponse(const BagsRewriteResponse& r);
+  /** the rewritten node */
+  Node d_node;
+  /** type of rewrite used by bags */
+  Rewrite d_rewrite;
+
+}; /* struct BagsRewriteResponse */
+
+class BagsRewriter : public TheoryRewriter
+{
+ public:
+  BagsRewriter(HistogramStat<Rewrite>* statistics = nullptr);
+
+  /**
+   * postRewrite nodes with kinds: MK_BAG, BAG_COUNT, UNION_MAX, UNION_DISJOINT,
+   * INTERSECTION_MIN, DIFFERENCE_SUBTRACT, DIFFERENCE_REMOVE, BAG_CHOOSE,
+   * BAG_CARD, BAG_IS_SINGLETON.
+   * See the rewrite rules for these kinds below.
+   */
+  RewriteResponse postRewrite(TNode n) override;
+  /**
+   * preRewrite nodes with kinds: EQUAL, BAG_IS_INCLUDED.
+   * See the rewrite rules for these kinds below.
+   */
+  RewriteResponse preRewrite(TNode n) override;
+
+ private:
+  /**
+   * rewrites for n include:
+   * - (= A A) = true where A is a bag
+   */
+  BagsRewriteResponse rewriteEqual(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (bag.is_included A B) = ((difference_subtract A B) == emptybag)
+   */
+  BagsRewriteResponse rewriteIsIncluded(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (mkBag x 0) = (emptybag T) where T is the type of x
+   * - (mkBag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
+   *   constant
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteMakeBag(const TNode& n) const;
+  /**
+   * rewrites for n include:
+   * - (bag.count x emptybag) = 0
+   * - (bag.count x (mkBag x c) = c where c > 0 is a constant
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteBagCount(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (union_max A emptybag) = A
+   * - (union_max emptybag A) = A
+   * - (union_max A A) = A
+   * - (union_max A (union_max A B)) = (union_max A B)
+   * - (union_max A (union_max B A)) = (union_max B A)
+   * - (union_max (union_max A B) A) = (union_max A B)
+   * - (union_max (union_max B A) A) = (union_max B A)
+   * - (union_max A (union_disjoint A B)) = (union_disjoint A B)
+   * - (union_max A (union_disjoint B A)) = (union_disjoint B A)
+   * - (union_max (union_disjoint A B) A) = (union_disjoint A B)
+   * - (union_max (union_disjoint B A) A) = (union_disjoint B A)
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteUnionMax(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (union_disjoint A emptybag) = A
+   * - (union_disjoint emptybag A) = A
+   * - (union_disjoint (union_max A B) (intersection_min A B)) =
+   *         (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+   * - other permutations of the above like swapping A and B, or swapping
+   *   intersection_min and union_max
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteUnionDisjoint(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (intersection_min A emptybag) = emptybag
+   * - (intersection_min emptybag A) = emptybag
+   * - (intersection_min A A) = A
+   * - (intersection_min A (union_disjoint A B)) = A
+   * - (intersection_min A (union_disjoint B A)) = A
+   * - (intersection_min (union_disjoint A B) A) = A
+   * - (intersection_min (union_disjoint B A) A) = A
+   * - (intersection_min A (union_max A B)) = A
+   * - (intersection_min A (union_max B A)) = A
+   * - (intersection_min (union_max A B) A) = A
+   * - (intersection_min (union_max B A) A) = A
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteIntersectionMin(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (difference_subtract A emptybag) = A
+   * - (difference_subtract emptybag A) = emptybag
+   * - (difference_subtract A A) = emptybag
+   * - (difference_subtract (union_disjoint A B) A) = B
+   * - (difference_subtract (union_disjoint B A) A) = B
+   * - (difference_subtract A (union_disjoint A B)) = emptybag
+   * - (difference_subtract A (union_disjoint B A)) = emptybag
+   * - (difference_subtract A (union_max A B)) = emptybag
+   * - (difference_subtract A (union_max B A)) = emptybag
+   * - (difference_subtract (intersection_min A B) A) = emptybag
+   * - (difference_subtract (intersection_min B A) A) = emptybag
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteDifferenceSubtract(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (difference_remove A emptybag) = A
+   * - (difference_remove emptybag A) = emptybag
+   * - (difference_remove A A) = emptybag
+   * - (difference_remove A (union_disjoint A B)) = emptybag
+   * - (difference_remove A (union_disjoint B A)) = emptybag
+   * - (difference_remove A (union_max A B)) = emptybag
+   * - (difference_remove A (union_max B A)) = emptybag
+   * - (difference_remove (intersection_min A B) A) = emptybag
+   * - (difference_remove (intersection_min B A) A) = emptybag
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteDifferenceRemove(const TNode& n) const;
+  /**
+   * rewrites for n include:
+   * - (bag.choose (mkBag x c)) = x where c is a constant > 0
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteChoose(const TNode& n) const;
+  /**
+   * rewrites for n include:
+   * - (bag.card (mkBag x c)) = c where c is a constant > 0
+   * - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+   * - otherwise = n
+   */
+  BagsRewriteResponse rewriteCard(const TNode& n) const;
+
+  /**
+   * rewrites for n include:
+   * - (bag.is_singleton (mkBag x c)) = (c == 1)
+   */
+  BagsRewriteResponse rewriteIsSingleton(const TNode& n) const;
+
+ private:
+  /** Reference to the rewriter statistics. */
+  NodeManager* d_nm;
+  /** Reference to the rewriter statistics. */
+  HistogramStat<Rewrite>* d_statistics;
+}; /* class TheoryBagsRewriter */
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H */
diff --git a/src/theory/bags/bags_statistics.cpp b/src/theory/bags/bags_statistics.cpp
new file mode 100644 (file)
index 0000000..ea3d304
--- /dev/null
@@ -0,0 +1,35 @@
+/*********************                                                        */
+/*! \file bags_statistics.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Statistics for the theory of bags
+ **/
+
+#include "theory/bags/bags_statistics.h"
+
+#include "smt/smt_statistics_registry.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+BagsStatistics::BagsStatistics() : d_rewrites("theory::bags::rewrites")
+{
+  smtStatisticsRegistry()->registerStat(&d_rewrites);
+}
+
+BagsStatistics::~BagsStatistics()
+{
+  smtStatisticsRegistry()->unregisterStat(&d_rewrites);
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/bags/bags_statistics.h b/src/theory/bags/bags_statistics.h
new file mode 100644 (file)
index 0000000..457e3a3
--- /dev/null
@@ -0,0 +1,45 @@
+/*********************                                                        */
+/*! \file bags_statistics.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Statistics for the theory of bags
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__BAGS_STATISTICS_H
+#define CVC4__THEORY__BAGS_STATISTICS_H
+
+#include "expr/kind.h"
+#include "theory/bags/rewrites.h"
+#include "util/statistics_registry.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/**
+ * Statistics for the theory of bags.
+ */
+class BagsStatistics
+{
+ public:
+  BagsStatistics();
+  ~BagsStatistics();
+
+  /** Counts the number of applications of each type of rewrite rule */
+  HistogramStat<Rewrite> d_rewrites;
+};
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS_STATISTICS_H */
index 8093448a0466dd1b61cbcadc28fa8a8eaf6c6698..cdbef58dee64cb45b1fcf43903e4df99e33c9ab8 100644 (file)
@@ -8,8 +8,8 @@ theory THEORY_BAGS \
     ::CVC4::theory::bags::TheoryBags \
     "theory/bags/theory_bags.h"
 typechecker "theory/bags/theory_bags_type_rules.h"
-rewriter ::CVC4::theory::bags::TheoryBagsRewriter \
-    "theory/bags/theory_bags_rewriter.h"
+rewriter ::CVC4::theory::bags::BagsRewriter \
+    "theory/bags/bags_rewriter.h"
 
 properties parametric
 properties check propagate presolve
index d9248615b647491217d16b7abcde5257e47da0a9..facad3c927551e9b067e68020b55798e53bc8c0f 100644 (file)
@@ -22,6 +22,16 @@ bool NormalForm::checkNormalConstant(TNode n)
   return false;
 }
 
+bool NormalForm::AreChildrenConstants(TNode n)
+{
+  return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
+}
+
+Node NormalForm::evaluate(TNode n)
+{
+  // TODO(projects#223): complete this function
+  return CVC4::Node();
+}
 }  // namespace bags
 }  // namespace theory
 }  // namespace CVC4
\ No newline at end of file
index 73fd8dba834c8caeb36b4d2878bbdfd67feadefb..8c719fe81c1838d7175326d329ce3525da563931 100644 (file)
  ** \brief Normal form for bag constants.
  **/
 
-#include "cvc4_private.h"
 #include <expr/node.h>
 
+#include "cvc4_private.h"
+
 #ifndef CVC4__THEORY__BAGS__NORMAL_FORM_H
 #define CVC4__THEORY__BAGS__NORMAL_FORM_H
 
@@ -36,6 +37,14 @@ class NormalForm
    * Also handles the corner cases of empty bag and singleton bag.
    */
   static bool checkNormalConstant(TNode n);
+  /**
+   * check whether all children of the given node are in normal form
+   */
+  static bool AreChildrenConstants(TNode n);
+  /**
+   * evaluate the node n to a constant value
+   */
+  static Node evaluate(TNode n);
 };
 }  // namespace bags
 }  // namespace theory
diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp
new file mode 100644 (file)
index 0000000..758f8a6
--- /dev/null
@@ -0,0 +1,76 @@
+/*********************                                                        */
+/*! \file rewrites.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of inference information utility.
+ **/
+
+#include "theory/bags/rewrites.h"
+
+#include <iostream>
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+const char* toString(Rewrite r)
+{
+  switch (r)
+  {
+    case Rewrite::NONE: return "NONE";
+    case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT";
+    case Rewrite::CARD_MK_BAG: return "CARD_MK_BAG";
+    case Rewrite::CHOOSE_MK_BAG: return "CHOOSE_MK_BAG";
+    case Rewrite::CONSTANT_EVALUATION: return "CONSTANT_EVALUATION";
+    case Rewrite::COUNT_EMPTY: return "COUNT_EMPTY";
+    case Rewrite::COUNT_MK_BAG: return "COUNT_MK_BAG";
+    case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES";
+    case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT";
+    case Rewrite::INTERSECTION_EMPTY_RIGHT: return "INTERSECTION_EMPTY_RIGHT";
+    case Rewrite::INTERSECTION_SAME: return "INTERSECTION_SAME";
+    case Rewrite::INTERSECTION_SHARED_LEFT: return "INTERSECTION_SHARED_LEFT";
+    case Rewrite::INTERSECTION_SHARED_RIGHT: return "INTERSECTION_SHARED_RIGHT";
+    case Rewrite::IS_SINGLETON_MK_BAG: return "IS_SINGLETON_MK_BAG";
+    case Rewrite::MK_BAG_COUNT_NEGATIVE: return "MK_BAG_COUNT_NEGATIVE";
+    case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION";
+    case Rewrite::REMOVE_MIN: return "REMOVE_MIN";
+    case Rewrite::REMOVE_RETURN_LEFT: return "REMOVE_RETURN_LEFT";
+    case Rewrite::REMOVE_SAME: return "REMOVE_SAME";
+    case Rewrite::SUB_BAG: return "SUB_BAG";
+    case Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT:
+      return "SUBTRACT_DISJOINT_SHARED_LEFT";
+    case Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT:
+      return "SUBTRACT_DISJOINT_SHARED_RIGHT";
+    case Rewrite::SUBTRACT_FROM_UNION: return "SUBTRACT_FROM_UNION";
+    case Rewrite::SUBTRACT_MIN: return "SUBTRACT_MIN";
+    case Rewrite::SUBTRACT_RETURN_LEFT: return "SUBTRACT_RETURN_LEFT";
+    case Rewrite::SUBTRACT_SAME: return "SUBTRACT_SAME";
+    case Rewrite::UNION_DISJOINT_EMPTY_LEFT: return "UNION_DISJOINT_EMPTY_LEFT";
+    case Rewrite::UNION_DISJOINT_EMPTY_RIGHT:
+      return "UNION_DISJOINT_EMPTY_RIGHT";
+    case Rewrite::UNION_DISJOINT_MAX_MIN: return "UNION_DISJOINT_MAX_MIN";
+    case Rewrite::UNION_MAX_EMPTY: return "UNION_MAX_EMPTY";
+    case Rewrite::UNION_MAX_SAME_OR_EMPTY: return "UNION_MAX_SAME_OR_EMPTY";
+    case Rewrite::UNION_MAX_UNION_LEFT: return "UNION_MAX_UNION_LEFT";
+    case Rewrite::UNION_MAX_UNION_RIGHT: return "UNION_MAX_UNION_RIGHT";
+
+    default: return "?";
+  }
+}
+
+std::ostream& operator<<(std::ostream& out, Rewrite r)
+{
+  out << toString(r);
+  return out;
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h
new file mode 100644 (file)
index 0000000..13b0ff2
--- /dev/null
@@ -0,0 +1,91 @@
+/*********************                                                        */
+/*! \file rewrites.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Type for rewrites for bags.
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__BAGS__REWRITES_H
+#define CVC4__THEORY__BAGS__REWRITES_H
+
+#include <iosfwd>
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/** Types of rewrites used by bags
+ *
+ * This rewrites are documented where they are used in the rewriter.
+ */
+enum class Rewrite : uint32_t
+{
+  NONE, // no rewrite happened
+  CARD_DISJOINT,
+  CARD_MK_BAG,
+  CHOOSE_MK_BAG,
+  CONSTANT_EVALUATION,
+  COUNT_EMPTY,
+  COUNT_MK_BAG,
+  IDENTICAL_NODES,
+  INTERSECTION_EMPTY_LEFT,
+  INTERSECTION_EMPTY_RIGHT,
+  INTERSECTION_SAME,
+  INTERSECTION_SHARED_LEFT,
+  INTERSECTION_SHARED_RIGHT,
+  IS_SINGLETON_MK_BAG,
+  MK_BAG_COUNT_NEGATIVE,
+  REMOVE_FROM_UNION,
+  REMOVE_MIN,
+  REMOVE_RETURN_LEFT,
+  REMOVE_SAME,
+  SUB_BAG,
+  SUBTRACT_DISJOINT_SHARED_LEFT,
+  SUBTRACT_DISJOINT_SHARED_RIGHT,
+  SUBTRACT_FROM_UNION,
+  SUBTRACT_MIN,
+  SUBTRACT_RETURN_LEFT,
+  SUBTRACT_SAME,
+  UNION_DISJOINT_EMPTY_LEFT,
+  UNION_DISJOINT_EMPTY_RIGHT,
+  UNION_DISJOINT_MAX_MIN,
+  UNION_MAX_EMPTY,
+  UNION_MAX_SAME_OR_EMPTY,
+  UNION_MAX_UNION_LEFT,
+  UNION_MAX_UNION_RIGHT
+};
+
+/**
+ * Converts an rewrite to a string. Note: This function is also used in
+ * `safe_print()`. Changing this functions name or signature will result in
+ * `safe_print()` printing "<unsupported>" instead of the proper strings for
+ * the enum values.
+ *
+ * @param r The rewrite
+ * @return The name of the rewrite
+ */
+const char* toString(Rewrite r);
+
+/**
+ * Writes an rewrite name to a stream.
+ *
+ * @param out The stream to write to
+ * @param r The rewrite to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, Rewrite r);
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS__REWRITES_H */
index 5ddd173025ad5fcc2c6e6fc8e0e603bbe964b16d..e4cd64b48d7adde7e5550472dd921f55bc4197ad 100644 (file)
@@ -29,8 +29,9 @@ TheoryBags::TheoryBags(context::Context* c,
     : Theory(THEORY_BAGS, c, u, out, valuation, logicInfo, pnm),
       d_state(c, u, valuation),
       d_im(*this, d_state, pnm),
-      d_rewriter(),
-      d_notify(*this, d_im)
+      d_notify(*this, d_im),
+      d_statistics(),
+      d_rewriter(&d_statistics.d_rewrites)
 {
   // use the official theory state and inference manager objects
   d_theoryState = &d_state;
index 44f7ae1b0b519a3465748c995bee047b4ce34e2e..08bc5f33aa8c78e68145a592afa22f4e198a8d30 100644 (file)
 
 #include <memory>
 
+#include "theory/bags/bags_rewriter.h"
+#include "theory/bags/bags_statistics.h"
 #include "theory/bags/inference_manager.h"
 #include "theory/bags/solver_state.h"
-#include "theory/bags/theory_bags_rewriter.h"
 #include "theory/theory.h"
 #include "theory/theory_eq_notify.h"
 #include "theory/uf/equality_engine.h"
@@ -95,8 +96,10 @@ class TheoryBags : public Theory
   InferenceManager d_im;
   /** Instance of the above class */
   NotifyClass d_notify;
+  /** Statistics for the theory of bags. */
+  BagsStatistics d_statistics;
   /** The theory rewriter for this theory. */
-  TheoryBagsRewriter d_rewriter;
+  BagsRewriter d_rewriter;
 
   void eqNotifyNewClass(TNode t);
   void eqNotifyMerge(TNode t1, TNode t2);
diff --git a/src/theory/bags/theory_bags_rewriter.cpp b/src/theory/bags/theory_bags_rewriter.cpp
deleted file mode 100644 (file)
index aaf0ab9..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-/*********************                                                        */
-/*! \file theory_bags_rewriter.cpp
- ** \verbatim
- ** Top contributors (to current version):
- **   Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved.  See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Bags theory rewriter.
- **/
-
-#include "theory/bags/theory_bags_rewriter.h"
-
-using namespace CVC4::kind;
-
-namespace CVC4 {
-namespace theory {
-namespace bags {
-
-RewriteResponse TheoryBagsRewriter::postRewrite(TNode node)
-{
-  // TODO(projects#225): complete the code here
-  return RewriteResponse(REWRITE_DONE, node);
-}
-
-RewriteResponse TheoryBagsRewriter::preRewrite(TNode node)
-{
-  // TODO(projects#225): complete the code here
-  return RewriteResponse(REWRITE_DONE, node);
-}
-
-}  // namespace bags
-}  // namespace theory
-}  // namespace CVC4
diff --git a/src/theory/bags/theory_bags_rewriter.h b/src/theory/bags/theory_bags_rewriter.h
deleted file mode 100644 (file)
index 7be8863..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-/*********************                                                        */
-/*! \file theory_bags_rewriter.h
- ** \verbatim
- ** Top contributors (to current version):
- **   Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved.  See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Bags theory rewriter.
- **/
-
-#include "cvc4_private.h"
-
-#ifndef CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
-#define CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
-
-#include "theory/rewriter.h"
-
-namespace CVC4 {
-namespace theory {
-namespace bags {
-
-class TheoryBagsRewriter : public TheoryRewriter
-{
- public:
-  RewriteResponse postRewrite(TNode node) override;
-
-  RewriteResponse preRewrite(TNode node) override;
-}; /* class TheoryBagsRewriter */
-
-}  // namespace bags
-}  // namespace theory
-}  // namespace CVC4
-
-#endif /* CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H */
index 26639afd851385c49a856c48463ec808001ab560..a1ba896c1d943e2bd4b89045e479b321702ad3ef 100644 (file)
@@ -66,7 +66,7 @@ class BagEnumerator : public TypeEnumeratorBase<BagEnumerator>
    *
    * This seems too expensive to implement.
    * For now we are implementing an obvious solution
-   * {(1,1)}, {(1,2)}, {(1,3)}, ... which works for both fininte and infinite
+   * {(1,1)}, {(1,2)}, {(1,3)}, ... which works for both finite and infinite
    * types
    */
   BagEnumerator& operator++() override;
index fc5a193488026d69e433688e2e57341d78e015e7..e4279479df0d4d047273e19b432ee41faa9dd486 100644 (file)
@@ -236,7 +236,7 @@ struct BagsProperties
 {
   static Cardinality computeCardinality(TypeNode type)
   {
-    return Cardinality::UNKNOWN_CARD;
+    return Cardinality::INTEGERS;
   }
 
   static bool isWellFounded(TypeNode type) { return type[0].isWellFounded(); }
index f40d9658b35eaafec9715d8b4deeae0a305c2f59..108471d4ab722148e23edc82cd54975a0cd5474e 100644 (file)
@@ -14,6 +14,7 @@ cvc4_add_unit_test_white(evaluator_white theory)
 cvc4_add_unit_test_white(logic_info_white theory)
 cvc4_add_unit_test_white(sequences_rewriter_white theory)
 cvc4_add_unit_test_white(theory_arith_white theory)
+cvc4_add_unit_test_white(theory_bags_rewriter_black theory)
 cvc4_add_unit_test_white(theory_bags_type_rules_black theory)
 cvc4_add_unit_test_white(theory_bv_rewriter_white theory)
 cvc4_add_unit_test_white(theory_bv_white theory)
diff --git a/test/unit/theory/theory_bags_rewriter_black.h b/test/unit/theory/theory_bags_rewriter_black.h
new file mode 100644 (file)
index 0000000..d518058
--- /dev/null
@@ -0,0 +1,593 @@
+/*********************                                                        */
+/*! \file theory_bags_rewriter_black.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Black box testing of bags rewriter
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/bags_rewriter.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsTypeRuleBlack : public CxxTest::TestSuite
+{
+ public:
+  void setUp() override
+  {
+    d_em.reset(new ExprManager());
+    d_smt.reset(new SmtEngine(d_em.get()));
+    d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+    d_smt->finishInit();
+    d_rewriter.reset(new BagsRewriter(nullptr));
+  }
+
+  void tearDown() override
+  {
+    d_rewriter.reset();
+    d_smt.reset();
+    d_nm.release();
+    d_em.reset();
+  }
+
+  std::vector<Node> getNStrings(size_t n)
+  {
+    std::vector<Node> elements(n);
+    for (size_t i = 0; i < n; i++)
+    {
+      elements[i] = d_nm->mkSkolem("x", d_nm->stringType());
+    }
+    return elements;
+  }
+
+  void testEmptyBagNormalForm()
+  {
+    Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
+    // empty bags are in normal form
+    TS_ASSERT(emptybag.isConst());
+    RewriteResponse response = d_rewriter->postRewrite(emptybag);
+    TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE);
+  }
+
+  void testBagEquality()
+  {
+    vector<Node> elements = getNStrings(2);
+    Node x = elements[0];
+    Node y = elements[1];
+    Node c = d_nm->mkSkolem("c", d_nm->integerType());
+    Node d = d_nm->mkSkolem("d", d_nm->integerType());
+    Node bagX = d_nm->mkNode(MK_BAG, x, c);
+    Node bagY = d_nm->mkNode(MK_BAG, y, d);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+
+    // (= A A) = true where A is a bag
+    Node n1 = emptyBag.eqNode(emptyBag);
+    RewriteResponse response1 = d_rewriter->preRewrite(n1);
+    TS_ASSERT(response1.d_node == d_nm->mkConst(true)
+              && response1.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testMkBagConstantElement()
+  {
+    vector<Node> elements = getNStrings(1);
+    Node negative =
+        d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1)));
+    Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0)));
+    Node positive =
+        d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1)));
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+    RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+    RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+    // bags with non-positive multiplicity are rewritten as empty bags
+    TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+              && negativeResponse.d_node == emptybag);
+    TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+              && zeroResponse.d_node == emptybag);
+
+    // no change for positive
+    TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+              && positive == positiveResponse.d_node);
+  }
+
+  void testMkBagVariableElement()
+  {
+    Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+    Node variable = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
+    Node negative = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
+    Node zero = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(0)));
+    Node positive = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(1)));
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+    RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+    RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+    // bags with non-positive multiplicity are rewritten as empty bags
+    TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+              && negativeResponse.d_node == emptybag);
+    TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+              && zeroResponse.d_node == emptybag);
+
+    // no change for positive
+    TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+              && positive == positiveResponse.d_node);
+  }
+
+  void testBagCount()
+  {
+    int n = 3;
+    Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+    Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType())));
+    Node bag = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(n)));
+
+    // (bag.count x emptybag) = 0
+    Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL
+              && response1.d_node == d_nm->mkConst(Rational(0)));
+
+    // (bag.count x (mkBag x c) = c where c > 0 is a constant
+    Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL
+              && response2.d_node == d_nm->mkConst(Rational(n)));
+  }
+
+  void testUnionMax()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+    // (union_max A emptybag) = A
+    Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max emptybag A) = A
+    Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
+    TS_ASSERT(response2.d_node == A
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A A) = A
+    Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
+    TS_ASSERT(response3.d_node == A
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_max A B)) = (union_max A B)
+    Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB);
+    RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
+    TS_ASSERT(response4.d_node == unionMaxAB
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_max B A)) = (union_max B A)
+    Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA);
+    RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
+    TS_ASSERT(response5.d_node == unionMaxBA
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_max A B) A) = (union_max A B)
+    Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A);
+    RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
+    TS_ASSERT(response6.d_node == unionMaxAB
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_max B A) A) = (union_max B A)
+    Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A);
+    RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
+    TS_ASSERT(response7.d_node == unionMaxBA
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_disjoint A B)) = (union_disjoint A B)
+    Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
+    TS_ASSERT(response8.d_node == unionDisjointAB
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_disjoint B A)) = (union_disjoint B A)
+    Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
+    TS_ASSERT(response9.d_node == unionDisjointBA
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_disjoint A B) A) = (union_disjoint A B)
+    Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
+    TS_ASSERT(response10.d_node == unionDisjointAB
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_disjoint B A) A) = (union_disjoint B A)
+    Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
+    TS_ASSERT(response11.d_node == unionDisjointBA
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testUnionDisjoint()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+    // (union_disjoint A emptybag) = A
+    Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_disjoint emptybag A) = A
+    Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
+    TS_ASSERT(response2.d_node == A
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_disjoint (union_max A B) (intersection_min B A)) =
+    //          (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+    Node unionDisjoint3 =
+        d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
+    RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
+    TS_ASSERT(response3.d_node == unionDisjointAB
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_disjoint (intersection_min B A)) (union_max A B) =
+    //          (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
+    Node unionDisjoint4 =
+        d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
+    RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
+    TS_ASSERT(response4.d_node == unionDisjointBA
+              && response4.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testIntersectionMin()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+    // (intersection_min A emptybag) = emptyBag
+    Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == emptyBag
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min emptybag A) = emptyBag
+    Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == emptyBag
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A A) = A
+    Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == A
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_max A B) = A
+    Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB);
+    RewriteResponse response4 = d_rewriter->postRewrite(n4);
+    TS_ASSERT(response4.d_node == A
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_max B A) = A
+    Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA);
+    RewriteResponse response5 = d_rewriter->postRewrite(n5);
+    TS_ASSERT(response5.d_node == A
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_max A B) A) = A
+    Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A);
+    RewriteResponse response6 = d_rewriter->postRewrite(n6);
+    TS_ASSERT(response6.d_node == A
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_max B A) A) = A
+    Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A);
+    RewriteResponse response7 = d_rewriter->postRewrite(n7);
+    TS_ASSERT(response7.d_node == A
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_disjoint A B) = A
+    Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(n8);
+    TS_ASSERT(response8.d_node == A
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_disjoint B A) = A
+    Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(n9);
+    TS_ASSERT(response9.d_node == A
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_disjoint A B) A) = A
+    Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(n10);
+    TS_ASSERT(response10.d_node == A
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_disjoint B A) A) = A
+    Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(n11);
+    TS_ASSERT(response11.d_node == A
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testDifferenceSubtract()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+    // (difference_subtract A emptybag) = A
+    Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract emptybag A) = emptyBag
+    Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == emptyBag
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A A) = emptybag
+    Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == emptyBag
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (union_disjoint A B) A) = B
+    Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
+    RewriteResponse response4 = d_rewriter->postRewrite(n4);
+    TS_ASSERT(response4.d_node == B
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (union_disjoint B A) A) = B
+    Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
+    RewriteResponse response5 = d_rewriter->postRewrite(n5);
+    TS_ASSERT(response5.d_node == B
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_disjoint A B)) = emptybag
+    Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
+    RewriteResponse response6 = d_rewriter->postRewrite(n6);
+    TS_ASSERT(response6.d_node == emptyBag
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_disjoint B A)) = emptybag
+    Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
+    RewriteResponse response7 = d_rewriter->postRewrite(n7);
+    TS_ASSERT(response7.d_node == emptyBag
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_max A B)) = emptybag
+    Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(n8);
+    TS_ASSERT(response8.d_node == emptyBag
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_max B A)) = emptybag
+    Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(n9);
+    TS_ASSERT(response9.d_node == emptyBag
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (intersection_min A B) A) = emptybag
+    Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(n10);
+    TS_ASSERT(response10.d_node == emptyBag
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (intersection_min B A) A) = emptybag
+    Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(n11);
+    TS_ASSERT(response11.d_node == emptyBag
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testDifferenceRemove()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+    // (difference_remove A emptybag) = A
+    Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove emptybag A) = emptyBag
+    Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == emptyBag
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A A) = emptybag
+    Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == emptyBag
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_disjoint A B)) = emptybag
+    Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
+    RewriteResponse response6 = d_rewriter->postRewrite(n6);
+    TS_ASSERT(response6.d_node == emptyBag
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_disjoint B A)) = emptybag
+    Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
+    RewriteResponse response7 = d_rewriter->postRewrite(n7);
+    TS_ASSERT(response7.d_node == emptyBag
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_max A B)) = emptybag
+    Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(n8);
+    TS_ASSERT(response8.d_node == emptyBag
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_max B A)) = emptybag
+    Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(n9);
+    TS_ASSERT(response9.d_node == emptyBag
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove (intersection_min A B) A) = emptybag
+    Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(n10);
+    TS_ASSERT(response10.d_node == emptyBag
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove (intersection_min B A) A) = emptybag
+    Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(n11);
+    TS_ASSERT(response11.d_node == emptyBag
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testChoose()
+  {
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node c = d_nm->mkConst(Rational(3));
+    Node bag = d_nm->mkNode(MK_BAG, x, c);
+
+    // (bag.choose (mkBag x c)) = x where c is a constant > 0
+    Node n1 = d_nm->mkNode(BAG_CHOOSE, bag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == x
+              && response1.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testBagCard()
+  {
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node zero = d_nm->mkConst(Rational(0));
+    Node c = d_nm->mkConst(Rational(3));
+    Node bag = d_nm->mkNode(MK_BAG, x, c);
+    vector<Node> elements = getNStrings(2);
+    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(4)));
+    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(5)));
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+
+    // TODO(projects#223): enable this test after implementing bags normal form
+    //    // (bag.card emptybag) = 0
+    //    Node n1 = d_nm->mkNode(BAG_CARD, emptyBag);
+    //    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    //    TS_ASSERT(response1.d_node == zero && response1.d_status ==
+    //    REWRITE_AGAIN_FULL);
+
+    // (bag.card (mkBag x c)) = c where c is a constant > 0
+    Node n2 = d_nm->mkNode(BAG_CARD, bag);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == c
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+    Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB);
+    Node cardA = d_nm->mkNode(BAG_CARD, A);
+    Node cardB = d_nm->mkNode(BAG_CARD, B);
+    Node plus = d_nm->mkNode(PLUS, cardA, cardB);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == plus
+              && response3.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testIsSingleton()
+  {
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node c = d_nm->mkSkolem("c", d_nm->integerType());
+    Node bag = d_nm->mkNode(MK_BAG, x, c);
+
+    // TODO(projects#223): complete this function
+    // (bag.is_singleton emptybag) = false
+    //    Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag);
+    //    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    //    TS_ASSERT(response1.d_node == d_nm->mkConst(false)
+    //              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (bag.is_singleton (mkBag x c) = (c == 1)
+    Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    Node one = d_nm->mkConst(Rational(1));
+    Node equal = c.eqNode(one);
+    TS_ASSERT(response2.d_node == equal
+              && response2.d_status == REWRITE_AGAIN_FULL);
+  }
+
+ private:
+  std::unique_ptr<ExprManager> d_em;
+  std::unique_ptr<SmtEngine> d_smt;
+  std::unique_ptr<NodeManager> d_nm;
+
+  std::unique_ptr<BagsRewriter> d_rewriter;
+}; /* class BagsTypeRuleBlack */