Add operator MakeBagOp for constructing bags (#5209)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Wed, 21 Oct 2020 13:19:55 +0000 (08:19 -0500)
committerGitHub <noreply@github.com>
Wed, 21 Oct 2020 13:19:55 +0000 (08:19 -0500)
This PR removes subtyping rules for bags and add operator MakeBagOp similar to SingletonOp

16 files changed:
src/CMakeLists.txt
src/api/cvc4cppkind.h
src/expr/node_manager.cpp
src/expr/node_manager.h
src/expr/type_node.cpp
src/theory/bags/bags_rewriter.cpp
src/theory/bags/kinds
src/theory/bags/make_bag_op.cpp [new file with mode: 0644]
src/theory/bags/make_bag_op.h [new file with mode: 0644]
src/theory/bags/theory_bags_type_enumerator.cpp
src/theory/bags/theory_bags_type_rules.h
test/unit/theory/CMakeLists.txt
test/unit/theory/theory_bags_rewriter_black.h [deleted file]
test/unit/theory/theory_bags_rewriter_white.h [new file with mode: 0644]
test/unit/theory/theory_bags_type_rules_black.h [deleted file]
test/unit/theory/theory_bags_type_rules_white.h [new file with mode: 0644]

index 4d96fa0b3f6640b3ec44ea3a8836c2e862b10e6f..5966debc19a2cd52c9d7125c401ace1516bb3b6d 100644 (file)
@@ -429,6 +429,8 @@ libcvc4_add_sources(
   theory/bags/bags_statistics.h
   theory/bags/inference_manager.cpp
   theory/bags/inference_manager.h
+  theory/bags/make_bag_op.cpp
+  theory/bags/make_bag_op.h
   theory/bags/normal_form.cpp
   theory/bags/normal_form.h
   theory/bags/rewrites.cpp
index 913a4a993519f6ba7a7f40130b5356b1329baf2f..d6ee24f1ebb55ccc20eb3dec784f2f1d35c7efcf 100644 (file)
@@ -1841,7 +1841,8 @@ enum CVC4_PUBLIC Kind : int32_t
    */
   MEMBER,
   /**
-   * The set of the single element given as a parameter.
+   * Construct a singleton set from an element given as a parameter.
+   * The returned set has same type of the element.
    * Parameters: 1
    *   -[1]: Single element
    * Create with:
index f8057006c809b609e31e8c8593275ba04beacb07..e9f121047b2a4c8e68bce7ede5cfadbfcce22123 100644 (file)
@@ -961,6 +961,17 @@ Node NodeManager::mkSingleton(const TypeNode& t, const TNode n)
   return singleton;
 }
 
+Node NodeManager::mkBag(const TypeNode& t, const TNode n, const TNode m)
+{
+  Assert(n.getType().isSubtypeOf(t))
+      << "Invalid operands for mkBag. The type '" << n.getType()
+      << "' of node '" << n << "' is not a subtype of '" << t << "'."
+      << std::endl;
+  Node op = mkConst(MakeBagOp(t));
+  Node bag = mkNode(kind::MK_BAG, op, n, m);
+  return bag;
+}
+
 Node NodeManager::mkAbstractValue(const TypeNode& type) {
   Node n = mkConst(AbstractValue(++d_abstractValueCount));
   n.setAttribute(TypeAttr(), type);
index 5427c3b6a5f4470ee4dfd479cb53c4e29f28a504..8f223752389ffa76a9e4f7a2dd8b0b7befda4e60 100644 (file)
@@ -578,12 +578,22 @@ class NodeManager {
   /**
   * Create a singleton set from the given element n.
   * @param t the element type of the returned set.
-  * Note that the type of n needs to be a subtype of t.
+  *          Note that the type of n needs to be a subtype of t.
   * @param n the single element in the singleton.
   * @return a singleton set constructed from the element n.
   */
   Node mkSingleton(const TypeNode& t, const TNode n);
 
+  /**
+  * Create a bag from the given element n along with its multiplicity m.
+  * @param t the element type of the returned bag.
+  *          Note that the type of n needs to be a subtype of t.
+  * @param n the element that is used to to construct the bag
+  * @param m the multiplicity of the element n
+  * @return a bag that contains m occurrences of n.
+  */
+  Node mkBag(const TypeNode& t, const TNode n, const TNode m);
+
   /**
    * Create a constant of type T.  It will have the appropriate
    * CONST_* kind defined for T.
index 659b1eef22da2ea2452bd139b52f44693643a91f..e917a9d0d75261a81680a2b0613f10201a37fd73 100644 (file)
@@ -574,11 +574,12 @@ TypeNode TypeNode::commonTypeNode(TypeNode t0, TypeNode t1, bool isLeast) {
     case kind::ARRAY_TYPE:
     case kind::DATATYPE_TYPE:
     case kind::PARAMETRIC_DATATYPE:
-    case kind::SEQUENCE_TYPE: return TypeNode();
+    case kind::SEQUENCE_TYPE:
     case kind::SET_TYPE:
+    case kind::BAG_TYPE:
     {
-      // we don't support subtyping for sets
-      return TypeNode(); // return null type
+      // we don't support subtyping except for built in types Int and Real.
+      return TypeNode();  // return null type
     }
     case kind::SEXPR_TYPE:
       Unimplemented()
index 1faaf55c0194093decf87d2f48896326d239d09d..c413a5e7e1eabfadefc565f644522a2eb17de3ad 100644 (file)
@@ -438,7 +438,8 @@ BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const
   {
     // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
     Node one = d_nm->mkConst(Rational(1));
-    Node bag = d_nm->mkNode(MK_BAG, n[0][0], one);
+    TypeNode type = n[0].getType().getSetElementType();
+    Node bag = d_nm->mkBag(type, n[0][0], one);
     return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON);
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
index 72326de080efa197f3182b7a3ca9137d8db33f66..86e89e0bd274b431e2b5a29b0ed5485d962f6cc0 100644 (file)
@@ -47,7 +47,14 @@ operator DIFFERENCE_REMOVE 2  "bag difference remove (removes shared elements)"
 
 operator BAG_IS_INCLUDED   2  "inclusion predicate for bags (less than or equal multiplicities)"
 operator BAG_COUNT         2  "multiplicity of an element in a bag"
-operator MK_BAG            2  "constructs a bag from one element along with its multiplicity"
+
+constant MK_BAG_OP \
+       ::CVC4::MakeBagOp \
+       ::CVC4::MakeBagOpHashFunction \
+       "theory/bags/make_bag_op.h" \
+       "operator for MK_BAG; payload is an instance of the CVC4::MakeBagOp class"
+parameterized MK_BAG MK_BAG_OP 2 \
+"constructs a bag from one element along with its multiplicity"
 
 # The operator bag-is-singleton returns whether the given bag is a singleton
 operator BAG_IS_SINGLETON  1  "return whether the given bag is a singleton"
@@ -69,6 +76,7 @@ typerule DIFFERENCE_SUBTRACT ::CVC4::theory::bags::BinaryOperatorTypeRule
 typerule DIFFERENCE_REMOVE   ::CVC4::theory::bags::BinaryOperatorTypeRule
 typerule BAG_IS_INCLUDED     ::CVC4::theory::bags::IsIncludedTypeRule
 typerule BAG_COUNT           ::CVC4::theory::bags::CountTypeRule
+typerule MK_BAG_OP           "SimpleTypeRule<RBuiltinOperator>"
 typerule MK_BAG              ::CVC4::theory::bags::MkBagTypeRule
 typerule EMPTYBAG            ::CVC4::theory::bags::EmptyBagTypeRule
 typerule BAG_CARD            ::CVC4::theory::bags::CardTypeRule
diff --git a/src/theory/bags/make_bag_op.cpp b/src/theory/bags/make_bag_op.cpp
new file mode 100644 (file)
index 0000000..6a535af
--- /dev/null
@@ -0,0 +1,48 @@
+/*********************                                                        */
+/*! \file bag_op.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief a class for MK_BAG operator
+ **/
+
+#include <iostream>
+
+#include "expr/type_node.h"
+#include "make_bag_op.h"
+
+namespace CVC4 {
+
+std::ostream& operator<<(std::ostream& out, const MakeBagOp& op)
+{
+  return out << "(mkBag_op " << op.getType() << ')';
+}
+
+size_t MakeBagOpHashFunction::operator()(const MakeBagOp& op) const
+{
+  return TypeNodeHashFunction()(op.getType());
+}
+
+MakeBagOp::MakeBagOp(const TypeNode& elementType)
+    : d_type(new TypeNode(elementType))
+{
+}
+
+MakeBagOp::MakeBagOp(const MakeBagOp& op) : d_type(new TypeNode(op.getType()))
+{
+}
+
+const TypeNode& MakeBagOp::getType() const { return *d_type; }
+
+bool MakeBagOp::operator==(const MakeBagOp& op) const
+{
+  return getType() == op.getType();
+}
+
+}  // namespace CVC4
diff --git a/src/theory/bags/make_bag_op.h b/src/theory/bags/make_bag_op.h
new file mode 100644 (file)
index 0000000..b479308
--- /dev/null
@@ -0,0 +1,63 @@
+/*********************                                                        */
+/*! \file mk_bag_op.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief a class for MK_BAG operator
+ **/
+
+#include "cvc4_public.h"
+
+#ifndef CVC4__MAKE_BAG_OP_H
+#define CVC4__MAKE_BAG_OP_H
+
+#include <memory>
+
+namespace CVC4 {
+
+class TypeNode;
+
+/**
+ * The class is an operator for kind MK_BAG used to construct bags.
+ * It specifies the type of the element especially when it is a constant.
+ * e.g. the type of rational 1 is Int, however
+ * (mkBag (mkBag_op Real) 1) is of type (Bag Real), not (Bag Int).
+ * Note that the type passed to the constructor is the element's type, not the
+ * bag type.
+ */
+class MakeBagOp
+{
+ public:
+  MakeBagOp(const TypeNode& elementType);
+  MakeBagOp(const MakeBagOp& op);
+
+  /** return the type of the current object */
+  const TypeNode& getType() const;
+
+  bool operator==(const MakeBagOp& op) const;
+
+ private:
+  MakeBagOp();
+  /** a pointer to the type of the bag element */
+  std::unique_ptr<TypeNode> d_type;
+}; /* class MakeBagOp */
+
+std::ostream& operator<<(std::ostream& out, const MakeBagOp& op);
+
+/**
+ * Hash function for the MakeBagOpHashFunction objects.
+ */
+struct CVC4_PUBLIC MakeBagOpHashFunction
+{
+  size_t operator()(const MakeBagOp& op) const;
+}; /* struct MakeBagOpHashFunction */
+
+}  // namespace CVC4
+
+#endif /* CVC4__MAKE_BAG_OP_H */
index 7975bb379da166100d9f4a9832eddbf840eecd98..727407937cd40f13cba30a85c98d7ef47e1b1ed6 100644 (file)
@@ -54,7 +54,8 @@ BagEnumerator& BagEnumerator::operator++()
 {
   // increase the multiplicity by one
   Node one = d_nodeManager->mkConst(Rational(1));
-  Node singleton = d_nodeManager->mkNode(kind::MK_BAG, d_element, one);
+  TypeNode elementType = d_elementTypeEnumerator.getType();
+  Node singleton = d_nodeManager->mkBag(elementType, d_element, one);
   if (d_currentBag.getKind() == kind::EMPTYBAG)
   {
     d_currentBag = singleton;
index 67293e2224e92045ed9149ce143265351b05a166..75f57ec885ddc007716950e9114a0e51d0b7576a 100644 (file)
@@ -42,19 +42,11 @@ struct BinaryOperatorTypeRule
       TypeNode secondBagType = n[1].getType(check);
       if (secondBagType != bagType)
       {
-        if (n.getKind() == kind::INTERSECTION_MIN)
-        {
-          bagType = TypeNode::mostCommonTypeNode(secondBagType, bagType);
-        }
-        else
-        {
-          bagType = TypeNode::leastCommonTypeNode(secondBagType, bagType);
-        }
-        if (bagType.isNull())
-        {
-          throw TypeCheckingExceptionPrivate(
-              n, "operator expects two bags of comparable types");
-        }
+        std::stringstream ss;
+        ss << "Operator " << n.getKind()
+           << " expects two bags of the same type. Found types '" << bagType
+           << "' and '" << secondBagType << "'.";
+        throw TypeCheckingExceptionPrivate(n, ss.str());
       }
     }
     return bagType;
@@ -110,15 +102,9 @@ struct CountTypeRule
             n, "checking for membership in a non-bag");
       }
       TypeNode elementType = n[0].getType(check);
-      // TODO(projects#226): comments from sets
-      //
-      // T : (Bag Int)
-      // B : (Bag Real)
-      // (= (as T (Bag Real)) B)
-      // (= (bag-count 0.5 B) 1)
-      // ...where (bag-count 0.5 T) is inferred
-
-      if (!elementType.isComparableTo(bagType.getBagElementType()))
+      // e.g. (count 1 (mkBag (mkBag_op Real) 1.0 3))) is 3 whereas
+      // (count 1.0 (mkBag (mkBag_op Int) 1 3))) throws a typing error
+      if (!elementType.isSubtypeOf(bagType.getBagElementType()))
       {
         std::stringstream ss;
         ss << "member operating on bags of different types:\n"
@@ -136,7 +122,10 @@ struct MkBagTypeRule
 {
   static TypeNode computeType(NodeManager* nm, TNode n, bool check)
   {
-    Assert(n.getKind() == kind::MK_BAG);
+    Assert(n.getKind() == kind::MK_BAG && n.hasOperator()
+           && n.getOperator().getKind() == kind::MK_BAG_OP);
+    MakeBagOp op = n.getOperator().getConst<MakeBagOp>();
+    TypeNode expectedElementType = op.getType();
     if (check)
     {
       if (n.getNumChildren() != 2)
@@ -153,9 +142,21 @@ struct MkBagTypeRule
         ss << "MK_BAG expects an integer for " << n[1] << ". Found" << type1;
         throw TypeCheckingExceptionPrivate(n, ss.str());
       }
+
+      TypeNode actualElementType = n[0].getType(check);
+      // the type of the element should be a subtype of the type of the operator
+      // e.g. (mkBag (mkBag_op Real) 1 1) where 1 is an Int
+      if (!actualElementType.isSubtypeOf(expectedElementType))
+      {
+        std::stringstream ss;
+        ss << "The type '" << actualElementType
+           << "' of the element is not a subtype of '" << expectedElementType
+           << "' in term : " << n;
+        throw TypeCheckingExceptionPrivate(n, ss.str());
+      }
     }
 
-    return nm->mkBagType(n[0].getType(check));
+    return nm->mkBagType(expectedElementType);
   }
 
   static bool computeIsConst(NodeManager* nodeManager, TNode n)
index e541a24fbfaa120a4c4c7986ff6f10a9d5779a3a..481c80f264812ed4b06a12c60f8a45f32e55a3e6 100644 (file)
@@ -14,8 +14,8 @@ cvc4_add_unit_test_white(evaluator_white theory)
 cvc4_add_unit_test_white(logic_info_white theory)
 cvc4_add_unit_test_white(sequences_rewriter_white theory)
 cvc4_add_unit_test_white(theory_arith_white theory)
-cvc4_add_unit_test_white(theory_bags_rewriter_black theory)
-cvc4_add_unit_test_white(theory_bags_type_rules_black theory)
+cvc4_add_unit_test_white(theory_bags_rewriter_white theory)
+cvc4_add_unit_test_white(theory_bags_type_rules_white theory)
 cvc4_add_unit_test_white(theory_bv_rewriter_white theory)
 cvc4_add_unit_test_white(theory_bv_white theory)
 cvc4_add_unit_test_white(theory_engine_white theory)
diff --git a/test/unit/theory/theory_bags_rewriter_black.h b/test/unit/theory/theory_bags_rewriter_black.h
deleted file mode 100644 (file)
index 98f56fd..0000000
+++ /dev/null
@@ -1,620 +0,0 @@
-/*********************                                                        */
-/*! \file theory_bags_rewriter_black.h
- ** \verbatim
- ** Top contributors (to current version):
- **   Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved.  See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Black box testing of bags rewriter
- **/
-
-#include <cxxtest/TestSuite.h>
-
-#include "expr/dtype.h"
-#include "smt/smt_engine.h"
-#include "theory/bags/bags_rewriter.h"
-#include "theory/strings/type_enumerator.h"
-
-using namespace CVC4;
-using namespace CVC4::smt;
-using namespace CVC4::theory;
-using namespace CVC4::kind;
-using namespace CVC4::theory::bags;
-using namespace std;
-
-typedef expr::Attribute<Node, Node> attribute;
-
-class BagsTypeRuleBlack : public CxxTest::TestSuite
-{
- public:
-  void setUp() override
-  {
-    d_em.reset(new ExprManager());
-    d_smt.reset(new SmtEngine(d_em.get()));
-    d_nm.reset(NodeManager::fromExprManager(d_em.get()));
-    d_smt->finishInit();
-    d_rewriter.reset(new BagsRewriter(nullptr));
-  }
-
-  void tearDown() override
-  {
-    d_rewriter.reset();
-    d_smt.reset();
-    d_nm.release();
-    d_em.reset();
-  }
-
-  std::vector<Node> getNStrings(size_t n)
-  {
-    std::vector<Node> elements(n);
-    for (size_t i = 0; i < n; i++)
-    {
-      elements[i] = d_nm->mkSkolem("x", d_nm->stringType());
-    }
-    return elements;
-  }
-
-  void testEmptyBagNormalForm()
-  {
-    Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
-    // empty bags are in normal form
-    TS_ASSERT(emptybag.isConst());
-    RewriteResponse response = d_rewriter->postRewrite(emptybag);
-    TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE);
-  }
-
-  void testBagEquality()
-  {
-    vector<Node> elements = getNStrings(2);
-    Node x = elements[0];
-    Node y = elements[1];
-    Node c = d_nm->mkSkolem("c", d_nm->integerType());
-    Node d = d_nm->mkSkolem("d", d_nm->integerType());
-    Node bagX = d_nm->mkNode(MK_BAG, x, c);
-    Node bagY = d_nm->mkNode(MK_BAG, y, d);
-    Node emptyBag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-
-    // (= A A) = true where A is a bag
-    Node n1 = emptyBag.eqNode(emptyBag);
-    RewriteResponse response1 = d_rewriter->preRewrite(n1);
-    TS_ASSERT(response1.d_node == d_nm->mkConst(true)
-              && response1.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testMkBagConstantElement()
-  {
-    vector<Node> elements = getNStrings(1);
-    Node negative =
-        d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1)));
-    Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0)));
-    Node positive =
-        d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1)));
-    Node emptybag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
-    RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
-    RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
-
-    // bags with non-positive multiplicity are rewritten as empty bags
-    TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
-              && negativeResponse.d_node == emptybag);
-    TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
-              && zeroResponse.d_node == emptybag);
-
-    // no change for positive
-    TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
-              && positive == positiveResponse.d_node);
-  }
-
-  void testMkBagVariableElement()
-  {
-    Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
-    Node variable = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
-    Node negative = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
-    Node zero = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(0)));
-    Node positive = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(1)));
-    Node emptybag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
-    RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
-    RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
-
-    // bags with non-positive multiplicity are rewritten as empty bags
-    TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
-              && negativeResponse.d_node == emptybag);
-    TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
-              && zeroResponse.d_node == emptybag);
-
-    // no change for positive
-    TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
-              && positive == positiveResponse.d_node);
-  }
-
-  void testBagCount()
-  {
-    int n = 3;
-    Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
-    Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType())));
-    Node bag = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(n)));
-
-    // (bag.count x emptybag) = 0
-    Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag);
-    RewriteResponse response1 = d_rewriter->postRewrite(n1);
-    TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL
-              && response1.d_node == d_nm->mkConst(Rational(0)));
-
-    // (bag.count x (mkBag x c) = c where c > 0 is a constant
-    Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag);
-    RewriteResponse response2 = d_rewriter->postRewrite(n2);
-    TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL
-              && response2.d_node == d_nm->mkConst(Rational(n)));
-  }
-
-  void testUnionMax()
-  {
-    int n = 3;
-    vector<Node> elements = getNStrings(2);
-    Node emptyBag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
-    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
-    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
-    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
-    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
-    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
-
-    // (union_max A emptybag) = A
-    Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag);
-    RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
-    TS_ASSERT(response1.d_node == A
-              && response1.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max emptybag A) = A
-    Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A);
-    RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
-    TS_ASSERT(response2.d_node == A
-              && response2.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max A A) = A
-    Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A);
-    RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
-    TS_ASSERT(response3.d_node == A
-              && response3.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max A (union_max A B)) = (union_max A B)
-    Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB);
-    RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
-    TS_ASSERT(response4.d_node == unionMaxAB
-              && response4.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max A (union_max B A)) = (union_max B A)
-    Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA);
-    RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
-    TS_ASSERT(response5.d_node == unionMaxBA
-              && response4.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max (union_max A B) A) = (union_max A B)
-    Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A);
-    RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
-    TS_ASSERT(response6.d_node == unionMaxAB
-              && response6.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max (union_max B A) A) = (union_max B A)
-    Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A);
-    RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
-    TS_ASSERT(response7.d_node == unionMaxBA
-              && response7.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max A (union_disjoint A B)) = (union_disjoint A B)
-    Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB);
-    RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
-    TS_ASSERT(response8.d_node == unionDisjointAB
-              && response8.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max A (union_disjoint B A)) = (union_disjoint B A)
-    Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA);
-    RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
-    TS_ASSERT(response9.d_node == unionDisjointBA
-              && response9.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max (union_disjoint A B) A) = (union_disjoint A B)
-    Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A);
-    RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
-    TS_ASSERT(response10.d_node == unionDisjointAB
-              && response10.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_max (union_disjoint B A) A) = (union_disjoint B A)
-    Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A);
-    RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
-    TS_ASSERT(response11.d_node == unionDisjointBA
-              && response11.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testUnionDisjoint()
-  {
-    int n = 3;
-    vector<Node> elements = getNStrings(2);
-    Node emptyBag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
-    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
-    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
-    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
-    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
-    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
-    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
-    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
-
-    // (union_disjoint A emptybag) = A
-    Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag);
-    RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
-    TS_ASSERT(response1.d_node == A
-              && response1.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_disjoint emptybag A) = A
-    Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A);
-    RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
-    TS_ASSERT(response2.d_node == A
-              && response2.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_disjoint (union_max A B) (intersection_min B A)) =
-    //          (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
-    Node unionDisjoint3 =
-        d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
-    RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
-    TS_ASSERT(response3.d_node == unionDisjointAB
-              && response3.d_status == REWRITE_AGAIN_FULL);
-
-    // (union_disjoint (intersection_min B A)) (union_max A B) =
-    //          (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
-    Node unionDisjoint4 =
-        d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
-    RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
-    TS_ASSERT(response4.d_node == unionDisjointBA
-              && response4.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testIntersectionMin()
-  {
-    int n = 3;
-    vector<Node> elements = getNStrings(2);
-    Node emptyBag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
-    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
-    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
-    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
-    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
-    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
-
-    // (intersection_min A emptybag) = emptyBag
-    Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag);
-    RewriteResponse response1 = d_rewriter->postRewrite(n1);
-    TS_ASSERT(response1.d_node == emptyBag
-              && response1.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min emptybag A) = emptyBag
-    Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A);
-    RewriteResponse response2 = d_rewriter->postRewrite(n2);
-    TS_ASSERT(response2.d_node == emptyBag
-              && response2.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min A A) = A
-    Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A);
-    RewriteResponse response3 = d_rewriter->postRewrite(n3);
-    TS_ASSERT(response3.d_node == A
-              && response3.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min A (union_max A B) = A
-    Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB);
-    RewriteResponse response4 = d_rewriter->postRewrite(n4);
-    TS_ASSERT(response4.d_node == A
-              && response4.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min A (union_max B A) = A
-    Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA);
-    RewriteResponse response5 = d_rewriter->postRewrite(n5);
-    TS_ASSERT(response5.d_node == A
-              && response4.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min (union_max A B) A) = A
-    Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A);
-    RewriteResponse response6 = d_rewriter->postRewrite(n6);
-    TS_ASSERT(response6.d_node == A
-              && response6.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min (union_max B A) A) = A
-    Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A);
-    RewriteResponse response7 = d_rewriter->postRewrite(n7);
-    TS_ASSERT(response7.d_node == A
-              && response7.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min A (union_disjoint A B) = A
-    Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
-    RewriteResponse response8 = d_rewriter->postRewrite(n8);
-    TS_ASSERT(response8.d_node == A
-              && response8.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min A (union_disjoint B A) = A
-    Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
-    RewriteResponse response9 = d_rewriter->postRewrite(n9);
-    TS_ASSERT(response9.d_node == A
-              && response9.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min (union_disjoint A B) A) = A
-    Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
-    RewriteResponse response10 = d_rewriter->postRewrite(n10);
-    TS_ASSERT(response10.d_node == A
-              && response10.d_status == REWRITE_AGAIN_FULL);
-
-    // (intersection_min (union_disjoint B A) A) = A
-    Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
-    RewriteResponse response11 = d_rewriter->postRewrite(n11);
-    TS_ASSERT(response11.d_node == A
-              && response11.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testDifferenceSubtract()
-  {
-    int n = 3;
-    vector<Node> elements = getNStrings(2);
-    Node emptyBag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
-    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
-    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
-    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
-    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
-    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
-    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
-    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
-
-    // (difference_subtract A emptybag) = A
-    Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
-    RewriteResponse response1 = d_rewriter->postRewrite(n1);
-    TS_ASSERT(response1.d_node == A
-              && response1.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract emptybag A) = emptyBag
-    Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
-    RewriteResponse response2 = d_rewriter->postRewrite(n2);
-    TS_ASSERT(response2.d_node == emptyBag
-              && response2.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract A A) = emptybag
-    Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A);
-    RewriteResponse response3 = d_rewriter->postRewrite(n3);
-    TS_ASSERT(response3.d_node == emptyBag
-              && response3.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract (union_disjoint A B) A) = B
-    Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
-    RewriteResponse response4 = d_rewriter->postRewrite(n4);
-    TS_ASSERT(response4.d_node == B
-              && response4.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract (union_disjoint B A) A) = B
-    Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
-    RewriteResponse response5 = d_rewriter->postRewrite(n5);
-    TS_ASSERT(response5.d_node == B
-              && response4.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract A (union_disjoint A B)) = emptybag
-    Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
-    RewriteResponse response6 = d_rewriter->postRewrite(n6);
-    TS_ASSERT(response6.d_node == emptyBag
-              && response6.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract A (union_disjoint B A)) = emptybag
-    Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
-    RewriteResponse response7 = d_rewriter->postRewrite(n7);
-    TS_ASSERT(response7.d_node == emptyBag
-              && response7.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract A (union_max A B)) = emptybag
-    Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
-    RewriteResponse response8 = d_rewriter->postRewrite(n8);
-    TS_ASSERT(response8.d_node == emptyBag
-              && response8.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract A (union_max B A)) = emptybag
-    Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
-    RewriteResponse response9 = d_rewriter->postRewrite(n9);
-    TS_ASSERT(response9.d_node == emptyBag
-              && response9.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract (intersection_min A B) A) = emptybag
-    Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
-    RewriteResponse response10 = d_rewriter->postRewrite(n10);
-    TS_ASSERT(response10.d_node == emptyBag
-              && response10.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_subtract (intersection_min B A) A) = emptybag
-    Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
-    RewriteResponse response11 = d_rewriter->postRewrite(n11);
-    TS_ASSERT(response11.d_node == emptyBag
-              && response11.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testDifferenceRemove()
-  {
-    int n = 3;
-    vector<Node> elements = getNStrings(2);
-    Node emptyBag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
-    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
-    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
-    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
-    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
-    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
-    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
-    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
-
-    // (difference_remove A emptybag) = A
-    Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
-    RewriteResponse response1 = d_rewriter->postRewrite(n1);
-    TS_ASSERT(response1.d_node == A
-              && response1.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove emptybag A) = emptyBag
-    Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
-    RewriteResponse response2 = d_rewriter->postRewrite(n2);
-    TS_ASSERT(response2.d_node == emptyBag
-              && response2.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove A A) = emptybag
-    Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A);
-    RewriteResponse response3 = d_rewriter->postRewrite(n3);
-    TS_ASSERT(response3.d_node == emptyBag
-              && response3.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove A (union_disjoint A B)) = emptybag
-    Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
-    RewriteResponse response6 = d_rewriter->postRewrite(n6);
-    TS_ASSERT(response6.d_node == emptyBag
-              && response6.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove A (union_disjoint B A)) = emptybag
-    Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
-    RewriteResponse response7 = d_rewriter->postRewrite(n7);
-    TS_ASSERT(response7.d_node == emptyBag
-              && response7.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove A (union_max A B)) = emptybag
-    Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
-    RewriteResponse response8 = d_rewriter->postRewrite(n8);
-    TS_ASSERT(response8.d_node == emptyBag
-              && response8.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove A (union_max B A)) = emptybag
-    Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
-    RewriteResponse response9 = d_rewriter->postRewrite(n9);
-    TS_ASSERT(response9.d_node == emptyBag
-              && response9.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove (intersection_min A B) A) = emptybag
-    Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
-    RewriteResponse response10 = d_rewriter->postRewrite(n10);
-    TS_ASSERT(response10.d_node == emptyBag
-              && response10.d_status == REWRITE_AGAIN_FULL);
-
-    // (difference_remove (intersection_min B A) A) = emptybag
-    Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
-    RewriteResponse response11 = d_rewriter->postRewrite(n11);
-    TS_ASSERT(response11.d_node == emptyBag
-              && response11.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testChoose()
-  {
-    Node x = d_nm->mkSkolem("x", d_nm->stringType());
-    Node c = d_nm->mkConst(Rational(3));
-    Node bag = d_nm->mkNode(MK_BAG, x, c);
-
-    // (bag.choose (mkBag x c)) = x where c is a constant > 0
-    Node n1 = d_nm->mkNode(BAG_CHOOSE, bag);
-    RewriteResponse response1 = d_rewriter->postRewrite(n1);
-    TS_ASSERT(response1.d_node == x
-              && response1.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testBagCard()
-  {
-    Node x = d_nm->mkSkolem("x", d_nm->stringType());
-    Node emptyBag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    Node zero = d_nm->mkConst(Rational(0));
-    Node c = d_nm->mkConst(Rational(3));
-    Node bag = d_nm->mkNode(MK_BAG, x, c);
-    vector<Node> elements = getNStrings(2);
-    Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(4)));
-    Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(5)));
-    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
-
-    // TODO(projects#223): enable this test after implementing bags normal form
-    //    // (bag.card emptybag) = 0
-    //    Node n1 = d_nm->mkNode(BAG_CARD, emptyBag);
-    //    RewriteResponse response1 = d_rewriter->postRewrite(n1);
-    //    TS_ASSERT(response1.d_node == zero && response1.d_status ==
-    //    REWRITE_AGAIN_FULL);
-
-    // (bag.card (mkBag x c)) = c where c is a constant > 0
-    Node n2 = d_nm->mkNode(BAG_CARD, bag);
-    RewriteResponse response2 = d_rewriter->postRewrite(n2);
-    TS_ASSERT(response2.d_node == c
-              && response2.d_status == REWRITE_AGAIN_FULL);
-
-    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
-    Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB);
-    Node cardA = d_nm->mkNode(BAG_CARD, A);
-    Node cardB = d_nm->mkNode(BAG_CARD, B);
-    Node plus = d_nm->mkNode(PLUS, cardA, cardB);
-    RewriteResponse response3 = d_rewriter->postRewrite(n3);
-    TS_ASSERT(response3.d_node == plus
-              && response3.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testIsSingleton()
-  {
-    Node emptybag =
-        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-    Node x = d_nm->mkSkolem("x", d_nm->stringType());
-    Node c = d_nm->mkSkolem("c", d_nm->integerType());
-    Node bag = d_nm->mkNode(MK_BAG, x, c);
-
-    // TODO(projects#223): complete this function
-    // (bag.is_singleton emptybag) = false
-    //    Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag);
-    //    RewriteResponse response1 = d_rewriter->postRewrite(n1);
-    //    TS_ASSERT(response1.d_node == d_nm->mkConst(false)
-    //              && response1.d_status == REWRITE_AGAIN_FULL);
-
-    // (bag.is_singleton (mkBag x c) = (c == 1)
-    Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag);
-    RewriteResponse response2 = d_rewriter->postRewrite(n2);
-    Node one = d_nm->mkConst(Rational(1));
-    Node equal = c.eqNode(one);
-    TS_ASSERT(response2.d_node == equal
-              && response2.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testFromSet()
-  {
-    Node x = d_nm->mkSkolem("x", d_nm->stringType());
-    Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
-
-    // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
-    Node n = d_nm->mkNode(BAG_FROM_SET, singleton);
-    RewriteResponse response = d_rewriter->postRewrite(n);
-    Node one = d_nm->mkConst(Rational(1));
-    Node bag = d_nm->mkNode(MK_BAG, x, one);
-    TS_ASSERT(response.d_node == bag
-              && response.d_status == REWRITE_AGAIN_FULL);
-  }
-
-  void testToSet()
-  {
-    Node x = d_nm->mkSkolem("x", d_nm->stringType());
-    Node bag = d_nm->mkNode(MK_BAG, x, d_nm->mkConst(Rational(5)));
-
-    // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
-    Node n = d_nm->mkNode(BAG_TO_SET, bag);
-    RewriteResponse response = d_rewriter->postRewrite(n);
-    Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
-    TS_ASSERT(response.d_node == singleton
-              && response.d_status == REWRITE_AGAIN_FULL);
-  }
-
- private:
-  std::unique_ptr<ExprManager> d_em;
-  std::unique_ptr<SmtEngine> d_smt;
-  std::unique_ptr<NodeManager> d_nm;
-
-  std::unique_ptr<BagsRewriter> d_rewriter;
-}; /* class BagsTypeRuleBlack */
diff --git a/test/unit/theory/theory_bags_rewriter_white.h b/test/unit/theory/theory_bags_rewriter_white.h
new file mode 100644 (file)
index 0000000..b1c75fd
--- /dev/null
@@ -0,0 +1,638 @@
+/*********************                                                        */
+/*! \file theory_bags_rewriter_white.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief White box testing of bags rewriter
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/bags_rewriter.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsTypeRuleWhite : public CxxTest::TestSuite
+{
+ public:
+  void setUp() override
+  {
+    d_em.reset(new ExprManager());
+    d_smt.reset(new SmtEngine(d_em.get()));
+    d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+    d_smt->finishInit();
+    d_rewriter.reset(new BagsRewriter(nullptr));
+  }
+
+  void tearDown() override
+  {
+    d_rewriter.reset();
+    d_smt.reset();
+    d_nm.release();
+    d_em.reset();
+  }
+
+  std::vector<Node> getNStrings(size_t n)
+  {
+    std::vector<Node> elements(n);
+    for (size_t i = 0; i < n; i++)
+    {
+      elements[i] = d_nm->mkSkolem("x", d_nm->stringType());
+    }
+    return elements;
+  }
+
+  void testEmptyBagNormalForm()
+  {
+    Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
+    // empty bags are in normal form
+    TS_ASSERT(emptybag.isConst());
+    RewriteResponse response = d_rewriter->postRewrite(emptybag);
+    TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE);
+  }
+
+  void testBagEquality()
+  {
+    vector<Node> elements = getNStrings(2);
+    Node x = elements[0];
+    Node y = elements[1];
+    Node c = d_nm->mkSkolem("c", d_nm->integerType());
+    Node d = d_nm->mkSkolem("d", d_nm->integerType());
+    Node bagX = d_nm->mkBag(d_nm->stringType(), x, c);
+    Node bagY = d_nm->mkBag(d_nm->stringType(), y, d);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+
+    // (= A A) = true where A is a bag
+    Node n1 = emptyBag.eqNode(emptyBag);
+    RewriteResponse response1 = d_rewriter->preRewrite(n1);
+    TS_ASSERT(response1.d_node == d_nm->mkConst(true)
+              && response1.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testMkBagConstantElement()
+  {
+    vector<Node> elements = getNStrings(1);
+    Node negative = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1)));
+    Node zero = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0)));
+    Node positive = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1)));
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+    RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+    RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+    // bags with non-positive multiplicity are rewritten as empty bags
+    TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+              && negativeResponse.d_node == emptybag);
+    TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+              && zeroResponse.d_node == emptybag);
+
+    // no change for positive
+    TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+              && positive == positiveResponse.d_node);
+  }
+
+  void testMkBagVariableElement()
+  {
+    Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+    Node variable =
+        d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(-1)));
+    Node negative =
+        d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(-1)));
+    Node zero =
+        d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(0)));
+    Node positive =
+        d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(1)));
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+    RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+    RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+    // bags with non-positive multiplicity are rewritten as empty bags
+    TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+              && negativeResponse.d_node == emptybag);
+    TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+              && zeroResponse.d_node == emptybag);
+
+    // no change for positive
+    TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+              && positive == positiveResponse.d_node);
+  }
+
+  void testBagCount()
+  {
+    int n = 3;
+    Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+    Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType())));
+    Node bag =
+        d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(n)));
+
+    // (bag.count x emptybag) = 0
+    Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL
+              && response1.d_node == d_nm->mkConst(Rational(0)));
+
+    // (bag.count x (mkBag x c) = c where c > 0 is a constant
+    Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL
+              && response2.d_node == d_nm->mkConst(Rational(n)));
+  }
+
+  void testUnionMax()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkBag(
+        d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+    // (union_max A emptybag) = A
+    Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max emptybag A) = A
+    Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
+    TS_ASSERT(response2.d_node == A
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A A) = A
+    Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
+    TS_ASSERT(response3.d_node == A
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_max A B)) = (union_max A B)
+    Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB);
+    RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
+    TS_ASSERT(response4.d_node == unionMaxAB
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_max B A)) = (union_max B A)
+    Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA);
+    RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
+    TS_ASSERT(response5.d_node == unionMaxBA
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_max A B) A) = (union_max A B)
+    Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A);
+    RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
+    TS_ASSERT(response6.d_node == unionMaxAB
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_max B A) A) = (union_max B A)
+    Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A);
+    RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
+    TS_ASSERT(response7.d_node == unionMaxBA
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_disjoint A B)) = (union_disjoint A B)
+    Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
+    TS_ASSERT(response8.d_node == unionDisjointAB
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max A (union_disjoint B A)) = (union_disjoint B A)
+    Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
+    TS_ASSERT(response9.d_node == unionDisjointBA
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_disjoint A B) A) = (union_disjoint A B)
+    Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
+    TS_ASSERT(response10.d_node == unionDisjointAB
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_max (union_disjoint B A) A) = (union_disjoint B A)
+    Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
+    TS_ASSERT(response11.d_node == unionDisjointBA
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testUnionDisjoint()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkBag(
+        d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+    // (union_disjoint A emptybag) = A
+    Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_disjoint emptybag A) = A
+    Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
+    TS_ASSERT(response2.d_node == A
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_disjoint (union_max A B) (intersection_min B A)) =
+    //          (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+    Node unionDisjoint3 =
+        d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
+    RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
+    TS_ASSERT(response3.d_node == unionDisjointAB
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_disjoint (intersection_min B A)) (union_max A B) =
+    //          (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
+    Node unionDisjoint4 =
+        d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
+    RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
+    TS_ASSERT(response4.d_node == unionDisjointBA
+              && response4.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testIntersectionMin()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkBag(
+        d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+    // (intersection_min A emptybag) = emptyBag
+    Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == emptyBag
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min emptybag A) = emptyBag
+    Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == emptyBag
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A A) = A
+    Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == A
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_max A B) = A
+    Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB);
+    RewriteResponse response4 = d_rewriter->postRewrite(n4);
+    TS_ASSERT(response4.d_node == A
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_max B A) = A
+    Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA);
+    RewriteResponse response5 = d_rewriter->postRewrite(n5);
+    TS_ASSERT(response5.d_node == A
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_max A B) A) = A
+    Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A);
+    RewriteResponse response6 = d_rewriter->postRewrite(n6);
+    TS_ASSERT(response6.d_node == A
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_max B A) A) = A
+    Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A);
+    RewriteResponse response7 = d_rewriter->postRewrite(n7);
+    TS_ASSERT(response7.d_node == A
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_disjoint A B) = A
+    Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(n8);
+    TS_ASSERT(response8.d_node == A
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min A (union_disjoint B A) = A
+    Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(n9);
+    TS_ASSERT(response9.d_node == A
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_disjoint A B) A) = A
+    Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(n10);
+    TS_ASSERT(response10.d_node == A
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (intersection_min (union_disjoint B A) A) = A
+    Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(n11);
+    TS_ASSERT(response11.d_node == A
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testDifferenceSubtract()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkBag(
+        d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+    // (difference_subtract A emptybag) = A
+    Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract emptybag A) = emptyBag
+    Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == emptyBag
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A A) = emptybag
+    Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == emptyBag
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (union_disjoint A B) A) = B
+    Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
+    RewriteResponse response4 = d_rewriter->postRewrite(n4);
+    TS_ASSERT(response4.d_node == B
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (union_disjoint B A) A) = B
+    Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
+    RewriteResponse response5 = d_rewriter->postRewrite(n5);
+    TS_ASSERT(response5.d_node == B
+              && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_disjoint A B)) = emptybag
+    Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
+    RewriteResponse response6 = d_rewriter->postRewrite(n6);
+    TS_ASSERT(response6.d_node == emptyBag
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_disjoint B A)) = emptybag
+    Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
+    RewriteResponse response7 = d_rewriter->postRewrite(n7);
+    TS_ASSERT(response7.d_node == emptyBag
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_max A B)) = emptybag
+    Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(n8);
+    TS_ASSERT(response8.d_node == emptyBag
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract A (union_max B A)) = emptybag
+    Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(n9);
+    TS_ASSERT(response9.d_node == emptyBag
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (intersection_min A B) A) = emptybag
+    Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(n10);
+    TS_ASSERT(response10.d_node == emptyBag
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_subtract (intersection_min B A) A) = emptybag
+    Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(n11);
+    TS_ASSERT(response11.d_node == emptyBag
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testDifferenceRemove()
+  {
+    int n = 3;
+    vector<Node> elements = getNStrings(2);
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+    Node B = d_nm->mkBag(
+        d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+    Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+    Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+    // (difference_remove A emptybag) = A
+    Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == A
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove emptybag A) = emptyBag
+    Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == emptyBag
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A A) = emptybag
+    Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == emptyBag
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_disjoint A B)) = emptybag
+    Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
+    RewriteResponse response6 = d_rewriter->postRewrite(n6);
+    TS_ASSERT(response6.d_node == emptyBag
+              && response6.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_disjoint B A)) = emptybag
+    Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
+    RewriteResponse response7 = d_rewriter->postRewrite(n7);
+    TS_ASSERT(response7.d_node == emptyBag
+              && response7.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_max A B)) = emptybag
+    Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
+    RewriteResponse response8 = d_rewriter->postRewrite(n8);
+    TS_ASSERT(response8.d_node == emptyBag
+              && response8.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove A (union_max B A)) = emptybag
+    Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
+    RewriteResponse response9 = d_rewriter->postRewrite(n9);
+    TS_ASSERT(response9.d_node == emptyBag
+              && response9.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove (intersection_min A B) A) = emptybag
+    Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
+    RewriteResponse response10 = d_rewriter->postRewrite(n10);
+    TS_ASSERT(response10.d_node == emptyBag
+              && response10.d_status == REWRITE_AGAIN_FULL);
+
+    // (difference_remove (intersection_min B A) A) = emptybag
+    Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
+    RewriteResponse response11 = d_rewriter->postRewrite(n11);
+    TS_ASSERT(response11.d_node == emptyBag
+              && response11.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testChoose()
+  {
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node c = d_nm->mkConst(Rational(3));
+    Node bag = d_nm->mkBag(d_nm->stringType(), x, c);
+
+    // (bag.choose (mkBag x c)) = x where c is a constant > 0
+    Node n1 = d_nm->mkNode(BAG_CHOOSE, bag);
+    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    TS_ASSERT(response1.d_node == x
+              && response1.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testBagCard()
+  {
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node emptyBag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node zero = d_nm->mkConst(Rational(0));
+    Node c = d_nm->mkConst(Rational(3));
+    Node bag = d_nm->mkBag(d_nm->stringType(), x, c);
+    vector<Node> elements = getNStrings(2);
+    Node A = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(4)));
+    Node B = d_nm->mkBag(
+        d_nm->stringType(), elements[1], d_nm->mkConst(Rational(5)));
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+
+    // TODO(projects#223): enable this test after implementing bags normal form
+    //    // (bag.card emptybag) = 0
+    //    Node n1 = d_nm->mkNode(BAG_CARD, emptyBag);
+    //    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    //    TS_ASSERT(response1.d_node == zero && response1.d_status ==
+    //    REWRITE_AGAIN_FULL);
+
+    // (bag.card (mkBag x c)) = c where c is a constant > 0
+    Node n2 = d_nm->mkNode(BAG_CARD, bag);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == c
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+    Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB);
+    Node cardA = d_nm->mkNode(BAG_CARD, A);
+    Node cardB = d_nm->mkNode(BAG_CARD, B);
+    Node plus = d_nm->mkNode(PLUS, cardA, cardB);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == plus
+              && response3.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testIsSingleton()
+  {
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node c = d_nm->mkSkolem("c", d_nm->integerType());
+    Node bag = d_nm->mkBag(d_nm->stringType(), x, c);
+
+    // TODO(projects#223): complete this function
+    // (bag.is_singleton emptybag) = false
+    //    Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag);
+    //    RewriteResponse response1 = d_rewriter->postRewrite(n1);
+    //    TS_ASSERT(response1.d_node == d_nm->mkConst(false)
+    //              && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (bag.is_singleton (mkBag x c) = (c == 1)
+    Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    Node one = d_nm->mkConst(Rational(1));
+    Node equal = c.eqNode(one);
+    TS_ASSERT(response2.d_node == equal
+              && response2.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testFromSet()
+  {
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
+
+    // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
+    Node n = d_nm->mkNode(BAG_FROM_SET, singleton);
+    RewriteResponse response = d_rewriter->postRewrite(n);
+    Node one = d_nm->mkConst(Rational(1));
+    Node bag = d_nm->mkBag(d_nm->stringType(), x, one);
+    TS_ASSERT(response.d_node == bag
+              && response.d_status == REWRITE_AGAIN_FULL);
+  }
+
+  void testToSet()
+  {
+    Node x = d_nm->mkSkolem("x", d_nm->stringType());
+    Node bag = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(5)));
+
+    // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
+    Node n = d_nm->mkNode(BAG_TO_SET, bag);
+    RewriteResponse response = d_rewriter->postRewrite(n);
+    Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
+    TS_ASSERT(response.d_node == singleton
+              && response.d_status == REWRITE_AGAIN_FULL);
+  }
+
+ private:
+  std::unique_ptr<ExprManager> d_em;
+  std::unique_ptr<SmtEngine> d_smt;
+  std::unique_ptr<NodeManager> d_nm;
+
+  std::unique_ptr<BagsRewriter> d_rewriter;
+}; /* class BagsTypeRuleBlack */
diff --git a/test/unit/theory/theory_bags_type_rules_black.h b/test/unit/theory/theory_bags_type_rules_black.h
deleted file mode 100644 (file)
index d6c225b..0000000
+++ /dev/null
@@ -1,111 +0,0 @@
-/*********************                                                        */
-/*! \file theory_bags_type_rules_black.h
- ** \verbatim
- ** Top contributors (to current version):
- **   Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved.  See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Black box testing of bags typing rules
- **/
-
-#include <cxxtest/TestSuite.h>
-
-#include "expr/dtype.h"
-#include "smt/smt_engine.h"
-#include "theory/bags/theory_bags_type_rules.h"
-#include "theory/strings/type_enumerator.h"
-
-using namespace CVC4;
-using namespace CVC4::smt;
-using namespace CVC4::theory;
-using namespace CVC4::kind;
-using namespace CVC4::theory::bags;
-using namespace std;
-
-typedef expr::Attribute<Node, Node> attribute;
-
-class BagsTypeRuleBlack : public CxxTest::TestSuite
-{
- public:
-  void setUp() override
-  {
-    d_em.reset(new ExprManager());
-    d_smt.reset(new SmtEngine(d_em.get()));
-    d_nm.reset(NodeManager::fromExprManager(d_em.get()));
-    d_smt->finishInit();
-  }
-
-  void tearDown() override
-  {
-    d_smt.reset();
-    d_nm.release();
-    d_em.reset();
-  }
-
-  std::vector<Node> getNStrings(size_t n)
-  {
-    std::vector<Node> elements(n);
-    CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType());
-
-    for (size_t i = 0; i < n; i++)
-    {
-      ++enumerator;
-      elements[i] = *enumerator;
-    }
-
-    return elements;
-  }
-
-  void testCountOperator()
-  {
-    vector<Node> elements = getNStrings(1);
-    Node bag = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(100)));
-
-    Node count = d_nm->mkNode(BAG_COUNT, elements[0], bag);
-    Node node = d_nm->mkConst(Rational(10));
-
-    // node of type Int is not compatible with bag of type (Bag String)
-    TS_ASSERT_THROWS(d_nm->mkNode(BAG_COUNT, node, bag).getType(true),
-                     TypeCheckingExceptionPrivate&);
-  }
-
-  void testMkBagOperator()
-  {
-    vector<Node> elements = getNStrings(1);
-    Node negative =
-        d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1)));
-    Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0)));
-    Node positive =
-        d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1)));
-
-    // only positive multiplicity are constants
-    TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), negative));
-    TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), zero));
-    TS_ASSERT(MkBagTypeRule::computeIsConst(d_nm.get(), positive));
-  }
-
-  void testFromSetOperator()
-  {
-    vector<Node> elements = getNStrings(1);
-    Node set = d_nm->mkSingleton(d_nm->stringType(), elements[0]);
-    TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_FROM_SET, set));
-    TS_ASSERT(d_nm->mkNode(BAG_FROM_SET, set).getType().isBag());
-  }
-
-  void testToSetOperator()
-  {
-    vector<Node> elements = getNStrings(1);
-    Node bag = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(10)));
-    TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag));
-    TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet());
-  }
-
- private:
-  std::unique_ptr<ExprManager> d_em;
-  std::unique_ptr<SmtEngine> d_smt;
-  std::unique_ptr<NodeManager> d_nm;
-}; /* class BagsTypeRuleBlack */
diff --git a/test/unit/theory/theory_bags_type_rules_white.h b/test/unit/theory/theory_bags_type_rules_white.h
new file mode 100644 (file)
index 0000000..dfe2d4c
--- /dev/null
@@ -0,0 +1,113 @@
+/*********************                                                        */
+/*! \file theory_bags_type_rules_black.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Black box testing of bags typing rules
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/theory_bags_type_rules.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsTypeRuleWhite : public CxxTest::TestSuite
+{
+ public:
+  void setUp() override
+  {
+    d_em.reset(new ExprManager());
+    d_smt.reset(new SmtEngine(d_em.get()));
+    d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+    d_smt->finishInit();
+  }
+
+  void tearDown() override
+  {
+    d_smt.reset();
+    d_nm.release();
+    d_em.reset();
+  }
+
+  std::vector<Node> getNStrings(size_t n)
+  {
+    std::vector<Node> elements(n);
+    CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType());
+
+    for (size_t i = 0; i < n; i++)
+    {
+      ++enumerator;
+      elements[i] = *enumerator;
+    }
+
+    return elements;
+  }
+
+  void testCountOperator()
+  {
+    vector<Node> elements = getNStrings(1);
+    Node bag = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(100)));
+
+    Node count = d_nm->mkNode(BAG_COUNT, elements[0], bag);
+    Node node = d_nm->mkConst(Rational(10));
+
+    // node of type Int is not compatible with bag of type (Bag String)
+    TS_ASSERT_THROWS(d_nm->mkNode(BAG_COUNT, node, bag).getType(true),
+                     TypeCheckingExceptionPrivate&);
+  }
+
+  void testMkBagOperator()
+  {
+    vector<Node> elements = getNStrings(1);
+    Node negative = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1)));
+    Node zero = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0)));
+    Node positive = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1)));
+
+    // only positive multiplicity are constants
+    TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), negative));
+    TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), zero));
+    TS_ASSERT(MkBagTypeRule::computeIsConst(d_nm.get(), positive));
+  }
+
+  void testFromSetOperator()
+  {
+    vector<Node> elements = getNStrings(1);
+    Node set = d_nm->mkSingleton(d_nm->stringType(), elements[0]);
+    TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_FROM_SET, set));
+    TS_ASSERT(d_nm->mkNode(BAG_FROM_SET, set).getType().isBag());
+  }
+
+  void testToSetOperator()
+  {
+    vector<Node> elements = getNStrings(1);
+    Node bag = d_nm->mkBag(d_nm->stringType(), elements[0], d_nm->mkConst(Rational(10)));
+    TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag));
+    TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet());
+  }
+
+ private:
+  std::unique_ptr<ExprManager> d_em;
+  std::unique_ptr<SmtEngine> d_smt;
+  std::unique_ptr<NodeManager> d_nm;
+}; /* class BagsTypeRuleBlack */