This PR removes subtyping rules for bags and add operator MakeBagOp similar to SingletonOp
theory/bags/bags_statistics.h
theory/bags/inference_manager.cpp
theory/bags/inference_manager.h
+ theory/bags/make_bag_op.cpp
+ theory/bags/make_bag_op.h
theory/bags/normal_form.cpp
theory/bags/normal_form.h
theory/bags/rewrites.cpp
*/
MEMBER,
/**
- * The set of the single element given as a parameter.
+ * Construct a singleton set from an element given as a parameter.
+ * The returned set has same type of the element.
* Parameters: 1
* -[1]: Single element
* Create with:
return singleton;
}
+Node NodeManager::mkBag(const TypeNode& t, const TNode n, const TNode m)
+{
+ Assert(n.getType().isSubtypeOf(t))
+ << "Invalid operands for mkBag. The type '" << n.getType()
+ << "' of node '" << n << "' is not a subtype of '" << t << "'."
+ << std::endl;
+ Node op = mkConst(MakeBagOp(t));
+ Node bag = mkNode(kind::MK_BAG, op, n, m);
+ return bag;
+}
+
Node NodeManager::mkAbstractValue(const TypeNode& type) {
Node n = mkConst(AbstractValue(++d_abstractValueCount));
n.setAttribute(TypeAttr(), type);
/**
* Create a singleton set from the given element n.
* @param t the element type of the returned set.
- * Note that the type of n needs to be a subtype of t.
+ * Note that the type of n needs to be a subtype of t.
* @param n the single element in the singleton.
* @return a singleton set constructed from the element n.
*/
Node mkSingleton(const TypeNode& t, const TNode n);
+ /**
+ * Create a bag from the given element n along with its multiplicity m.
+ * @param t the element type of the returned bag.
+ * Note that the type of n needs to be a subtype of t.
+ * @param n the element that is used to to construct the bag
+ * @param m the multiplicity of the element n
+ * @return a bag that contains m occurrences of n.
+ */
+ Node mkBag(const TypeNode& t, const TNode n, const TNode m);
+
/**
* Create a constant of type T. It will have the appropriate
* CONST_* kind defined for T.
case kind::ARRAY_TYPE:
case kind::DATATYPE_TYPE:
case kind::PARAMETRIC_DATATYPE:
- case kind::SEQUENCE_TYPE: return TypeNode();
+ case kind::SEQUENCE_TYPE:
case kind::SET_TYPE:
+ case kind::BAG_TYPE:
{
- // we don't support subtyping for sets
- return TypeNode(); // return null type
+ // we don't support subtyping except for built in types Int and Real.
+ return TypeNode(); // return null type
}
case kind::SEXPR_TYPE:
Unimplemented()
{
// (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
Node one = d_nm->mkConst(Rational(1));
- Node bag = d_nm->mkNode(MK_BAG, n[0][0], one);
+ TypeNode type = n[0].getType().getSetElementType();
+ Node bag = d_nm->mkBag(type, n[0][0], one);
return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON);
}
return BagsRewriteResponse(n, Rewrite::NONE);
operator BAG_IS_INCLUDED 2 "inclusion predicate for bags (less than or equal multiplicities)"
operator BAG_COUNT 2 "multiplicity of an element in a bag"
-operator MK_BAG 2 "constructs a bag from one element along with its multiplicity"
+
+constant MK_BAG_OP \
+ ::CVC4::MakeBagOp \
+ ::CVC4::MakeBagOpHashFunction \
+ "theory/bags/make_bag_op.h" \
+ "operator for MK_BAG; payload is an instance of the CVC4::MakeBagOp class"
+parameterized MK_BAG MK_BAG_OP 2 \
+"constructs a bag from one element along with its multiplicity"
# The operator bag-is-singleton returns whether the given bag is a singleton
operator BAG_IS_SINGLETON 1 "return whether the given bag is a singleton"
typerule DIFFERENCE_REMOVE ::CVC4::theory::bags::BinaryOperatorTypeRule
typerule BAG_IS_INCLUDED ::CVC4::theory::bags::IsIncludedTypeRule
typerule BAG_COUNT ::CVC4::theory::bags::CountTypeRule
+typerule MK_BAG_OP "SimpleTypeRule<RBuiltinOperator>"
typerule MK_BAG ::CVC4::theory::bags::MkBagTypeRule
typerule EMPTYBAG ::CVC4::theory::bags::EmptyBagTypeRule
typerule BAG_CARD ::CVC4::theory::bags::CardTypeRule
--- /dev/null
+/********************* */
+/*! \file bag_op.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief a class for MK_BAG operator
+ **/
+
+#include <iostream>
+
+#include "expr/type_node.h"
+#include "make_bag_op.h"
+
+namespace CVC4 {
+
+std::ostream& operator<<(std::ostream& out, const MakeBagOp& op)
+{
+ return out << "(mkBag_op " << op.getType() << ')';
+}
+
+size_t MakeBagOpHashFunction::operator()(const MakeBagOp& op) const
+{
+ return TypeNodeHashFunction()(op.getType());
+}
+
+MakeBagOp::MakeBagOp(const TypeNode& elementType)
+ : d_type(new TypeNode(elementType))
+{
+}
+
+MakeBagOp::MakeBagOp(const MakeBagOp& op) : d_type(new TypeNode(op.getType()))
+{
+}
+
+const TypeNode& MakeBagOp::getType() const { return *d_type; }
+
+bool MakeBagOp::operator==(const MakeBagOp& op) const
+{
+ return getType() == op.getType();
+}
+
+} // namespace CVC4
--- /dev/null
+/********************* */
+/*! \file mk_bag_op.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief a class for MK_BAG operator
+ **/
+
+#include "cvc4_public.h"
+
+#ifndef CVC4__MAKE_BAG_OP_H
+#define CVC4__MAKE_BAG_OP_H
+
+#include <memory>
+
+namespace CVC4 {
+
+class TypeNode;
+
+/**
+ * The class is an operator for kind MK_BAG used to construct bags.
+ * It specifies the type of the element especially when it is a constant.
+ * e.g. the type of rational 1 is Int, however
+ * (mkBag (mkBag_op Real) 1) is of type (Bag Real), not (Bag Int).
+ * Note that the type passed to the constructor is the element's type, not the
+ * bag type.
+ */
+class MakeBagOp
+{
+ public:
+ MakeBagOp(const TypeNode& elementType);
+ MakeBagOp(const MakeBagOp& op);
+
+ /** return the type of the current object */
+ const TypeNode& getType() const;
+
+ bool operator==(const MakeBagOp& op) const;
+
+ private:
+ MakeBagOp();
+ /** a pointer to the type of the bag element */
+ std::unique_ptr<TypeNode> d_type;
+}; /* class MakeBagOp */
+
+std::ostream& operator<<(std::ostream& out, const MakeBagOp& op);
+
+/**
+ * Hash function for the MakeBagOpHashFunction objects.
+ */
+struct CVC4_PUBLIC MakeBagOpHashFunction
+{
+ size_t operator()(const MakeBagOp& op) const;
+}; /* struct MakeBagOpHashFunction */
+
+} // namespace CVC4
+
+#endif /* CVC4__MAKE_BAG_OP_H */
{
// increase the multiplicity by one
Node one = d_nodeManager->mkConst(Rational(1));
- Node singleton = d_nodeManager->mkNode(kind::MK_BAG, d_element, one);
+ TypeNode elementType = d_elementTypeEnumerator.getType();
+ Node singleton = d_nodeManager->mkBag(elementType, d_element, one);
if (d_currentBag.getKind() == kind::EMPTYBAG)
{
d_currentBag = singleton;
TypeNode secondBagType = n[1].getType(check);
if (secondBagType != bagType)
{
- if (n.getKind() == kind::INTERSECTION_MIN)
- {
- bagType = TypeNode::mostCommonTypeNode(secondBagType, bagType);
- }
- else
- {
- bagType = TypeNode::leastCommonTypeNode(secondBagType, bagType);
- }
- if (bagType.isNull())
- {
- throw TypeCheckingExceptionPrivate(
- n, "operator expects two bags of comparable types");
- }
+ std::stringstream ss;
+ ss << "Operator " << n.getKind()
+ << " expects two bags of the same type. Found types '" << bagType
+ << "' and '" << secondBagType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
}
}
return bagType;
n, "checking for membership in a non-bag");
}
TypeNode elementType = n[0].getType(check);
- // TODO(projects#226): comments from sets
- //
- // T : (Bag Int)
- // B : (Bag Real)
- // (= (as T (Bag Real)) B)
- // (= (bag-count 0.5 B) 1)
- // ...where (bag-count 0.5 T) is inferred
-
- if (!elementType.isComparableTo(bagType.getBagElementType()))
+ // e.g. (count 1 (mkBag (mkBag_op Real) 1.0 3))) is 3 whereas
+ // (count 1.0 (mkBag (mkBag_op Int) 1 3))) throws a typing error
+ if (!elementType.isSubtypeOf(bagType.getBagElementType()))
{
std::stringstream ss;
ss << "member operating on bags of different types:\n"
{
static TypeNode computeType(NodeManager* nm, TNode n, bool check)
{
- Assert(n.getKind() == kind::MK_BAG);
+ Assert(n.getKind() == kind::MK_BAG && n.hasOperator()
+ && n.getOperator().getKind() == kind::MK_BAG_OP);
+ MakeBagOp op = n.getOperator().getConst<MakeBagOp>();
+ TypeNode expectedElementType = op.getType();
if (check)
{
if (n.getNumChildren() != 2)
ss << "MK_BAG expects an integer for " << n[1] << ". Found" << type1;
throw TypeCheckingExceptionPrivate(n, ss.str());
}
+
+ TypeNode actualElementType = n[0].getType(check);
+ // the type of the element should be a subtype of the type of the operator
+ // e.g. (mkBag (mkBag_op Real) 1 1) where 1 is an Int
+ if (!actualElementType.isSubtypeOf(expectedElementType))
+ {
+ std::stringstream ss;
+ ss << "The type '" << actualElementType
+ << "' of the element is not a subtype of '" << expectedElementType
+ << "' in term : " << n;
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
}
- return nm->mkBagType(n[0].getType(check));
+ return nm->mkBagType(expectedElementType);
}
static bool computeIsConst(NodeManager* nodeManager, TNode n)
cvc4_add_unit_test_white(logic_info_white theory)
cvc4_add_unit_test_white(sequences_rewriter_white theory)
cvc4_add_unit_test_white(theory_arith_white theory)
-cvc4_add_unit_test_white(theory_bags_rewriter_black theory)
-cvc4_add_unit_test_white(theory_bags_type_rules_black theory)
+cvc4_add_unit_test_white(theory_bags_rewriter_white theory)
+cvc4_add_unit_test_white(theory_bags_type_rules_white theory)
cvc4_add_unit_test_white(theory_bv_rewriter_white theory)
cvc4_add_unit_test_white(theory_bv_white theory)
cvc4_add_unit_test_white(theory_engine_white theory)
+++ /dev/null
-/********************* */
-/*! \file theory_bags_rewriter_black.h
- ** \verbatim
- ** Top contributors (to current version):
- ** Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved. See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Black box testing of bags rewriter
- **/
-
-#include <cxxtest/TestSuite.h>
-
-#include "expr/dtype.h"
-#include "smt/smt_engine.h"
-#include "theory/bags/bags_rewriter.h"
-#include "theory/strings/type_enumerator.h"
-
-using namespace CVC4;
-using namespace CVC4::smt;
-using namespace CVC4::theory;
-using namespace CVC4::kind;
-using namespace CVC4::theory::bags;
-using namespace std;
-
-typedef expr::Attribute<Node, Node> attribute;
-
-class BagsTypeRuleBlack : public CxxTest::TestSuite
-{
- public:
- void setUp() override
- {
- d_em.reset(new ExprManager());
- d_smt.reset(new SmtEngine(d_em.get()));
- d_nm.reset(NodeManager::fromExprManager(d_em.get()));
- d_smt->finishInit();
- d_rewriter.reset(new BagsRewriter(nullptr));
- }
-
- void tearDown() override
- {
- d_rewriter.reset();
- d_smt.reset();
- d_nm.release();
- d_em.reset();
- }
-
- std::vector<Node> getNStrings(size_t n)
- {
- std::vector<Node> elements(n);
- for (size_t i = 0; i < n; i++)
- {
- elements[i] = d_nm->mkSkolem("x", d_nm->stringType());
- }
- return elements;
- }
-
- void testEmptyBagNormalForm()
- {
- Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
- // empty bags are in normal form
- TS_ASSERT(emptybag.isConst());
- RewriteResponse response = d_rewriter->postRewrite(emptybag);
- TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE);
- }
-
- void testBagEquality()
- {
- vector<Node> elements = getNStrings(2);
- Node x = elements[0];
- Node y = elements[1];
- Node c = d_nm->mkSkolem("c", d_nm->integerType());
- Node d = d_nm->mkSkolem("d", d_nm->integerType());
- Node bagX = d_nm->mkNode(MK_BAG, x, c);
- Node bagY = d_nm->mkNode(MK_BAG, y, d);
- Node emptyBag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
-
- // (= A A) = true where A is a bag
- Node n1 = emptyBag.eqNode(emptyBag);
- RewriteResponse response1 = d_rewriter->preRewrite(n1);
- TS_ASSERT(response1.d_node == d_nm->mkConst(true)
- && response1.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testMkBagConstantElement()
- {
- vector<Node> elements = getNStrings(1);
- Node negative =
- d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1)));
- Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0)));
- Node positive =
- d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1)));
- Node emptybag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
- RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
- RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
-
- // bags with non-positive multiplicity are rewritten as empty bags
- TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
- && negativeResponse.d_node == emptybag);
- TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
- && zeroResponse.d_node == emptybag);
-
- // no change for positive
- TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
- && positive == positiveResponse.d_node);
- }
-
- void testMkBagVariableElement()
- {
- Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
- Node variable = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
- Node negative = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
- Node zero = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(0)));
- Node positive = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(1)));
- Node emptybag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
- RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
- RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
-
- // bags with non-positive multiplicity are rewritten as empty bags
- TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
- && negativeResponse.d_node == emptybag);
- TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
- && zeroResponse.d_node == emptybag);
-
- // no change for positive
- TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
- && positive == positiveResponse.d_node);
- }
-
- void testBagCount()
- {
- int n = 3;
- Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
- Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType())));
- Node bag = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(n)));
-
- // (bag.count x emptybag) = 0
- Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag);
- RewriteResponse response1 = d_rewriter->postRewrite(n1);
- TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL
- && response1.d_node == d_nm->mkConst(Rational(0)));
-
- // (bag.count x (mkBag x c) = c where c > 0 is a constant
- Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag);
- RewriteResponse response2 = d_rewriter->postRewrite(n2);
- TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL
- && response2.d_node == d_nm->mkConst(Rational(n)));
- }
-
- void testUnionMax()
- {
- int n = 3;
- vector<Node> elements = getNStrings(2);
- Node emptyBag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
- Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
- Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
- Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
- Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
- Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
-
- // (union_max A emptybag) = A
- Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag);
- RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
- TS_ASSERT(response1.d_node == A
- && response1.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max emptybag A) = A
- Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A);
- RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
- TS_ASSERT(response2.d_node == A
- && response2.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max A A) = A
- Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A);
- RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
- TS_ASSERT(response3.d_node == A
- && response3.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max A (union_max A B)) = (union_max A B)
- Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB);
- RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
- TS_ASSERT(response4.d_node == unionMaxAB
- && response4.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max A (union_max B A)) = (union_max B A)
- Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA);
- RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
- TS_ASSERT(response5.d_node == unionMaxBA
- && response4.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max (union_max A B) A) = (union_max A B)
- Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A);
- RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
- TS_ASSERT(response6.d_node == unionMaxAB
- && response6.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max (union_max B A) A) = (union_max B A)
- Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A);
- RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
- TS_ASSERT(response7.d_node == unionMaxBA
- && response7.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max A (union_disjoint A B)) = (union_disjoint A B)
- Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB);
- RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
- TS_ASSERT(response8.d_node == unionDisjointAB
- && response8.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max A (union_disjoint B A)) = (union_disjoint B A)
- Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA);
- RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
- TS_ASSERT(response9.d_node == unionDisjointBA
- && response9.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max (union_disjoint A B) A) = (union_disjoint A B)
- Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A);
- RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
- TS_ASSERT(response10.d_node == unionDisjointAB
- && response10.d_status == REWRITE_AGAIN_FULL);
-
- // (union_max (union_disjoint B A) A) = (union_disjoint B A)
- Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A);
- RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
- TS_ASSERT(response11.d_node == unionDisjointBA
- && response11.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testUnionDisjoint()
- {
- int n = 3;
- vector<Node> elements = getNStrings(2);
- Node emptyBag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
- Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
- Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
- Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
- Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
- Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
- Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
- Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
-
- // (union_disjoint A emptybag) = A
- Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag);
- RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
- TS_ASSERT(response1.d_node == A
- && response1.d_status == REWRITE_AGAIN_FULL);
-
- // (union_disjoint emptybag A) = A
- Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A);
- RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
- TS_ASSERT(response2.d_node == A
- && response2.d_status == REWRITE_AGAIN_FULL);
-
- // (union_disjoint (union_max A B) (intersection_min B A)) =
- // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
- Node unionDisjoint3 =
- d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
- RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
- TS_ASSERT(response3.d_node == unionDisjointAB
- && response3.d_status == REWRITE_AGAIN_FULL);
-
- // (union_disjoint (intersection_min B A)) (union_max A B) =
- // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
- Node unionDisjoint4 =
- d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
- RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
- TS_ASSERT(response4.d_node == unionDisjointBA
- && response4.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testIntersectionMin()
- {
- int n = 3;
- vector<Node> elements = getNStrings(2);
- Node emptyBag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
- Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
- Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
- Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
- Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
- Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
-
- // (intersection_min A emptybag) = emptyBag
- Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag);
- RewriteResponse response1 = d_rewriter->postRewrite(n1);
- TS_ASSERT(response1.d_node == emptyBag
- && response1.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min emptybag A) = emptyBag
- Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A);
- RewriteResponse response2 = d_rewriter->postRewrite(n2);
- TS_ASSERT(response2.d_node == emptyBag
- && response2.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min A A) = A
- Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A);
- RewriteResponse response3 = d_rewriter->postRewrite(n3);
- TS_ASSERT(response3.d_node == A
- && response3.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min A (union_max A B) = A
- Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB);
- RewriteResponse response4 = d_rewriter->postRewrite(n4);
- TS_ASSERT(response4.d_node == A
- && response4.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min A (union_max B A) = A
- Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA);
- RewriteResponse response5 = d_rewriter->postRewrite(n5);
- TS_ASSERT(response5.d_node == A
- && response4.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min (union_max A B) A) = A
- Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A);
- RewriteResponse response6 = d_rewriter->postRewrite(n6);
- TS_ASSERT(response6.d_node == A
- && response6.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min (union_max B A) A) = A
- Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A);
- RewriteResponse response7 = d_rewriter->postRewrite(n7);
- TS_ASSERT(response7.d_node == A
- && response7.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min A (union_disjoint A B) = A
- Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
- RewriteResponse response8 = d_rewriter->postRewrite(n8);
- TS_ASSERT(response8.d_node == A
- && response8.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min A (union_disjoint B A) = A
- Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
- RewriteResponse response9 = d_rewriter->postRewrite(n9);
- TS_ASSERT(response9.d_node == A
- && response9.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min (union_disjoint A B) A) = A
- Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
- RewriteResponse response10 = d_rewriter->postRewrite(n10);
- TS_ASSERT(response10.d_node == A
- && response10.d_status == REWRITE_AGAIN_FULL);
-
- // (intersection_min (union_disjoint B A) A) = A
- Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
- RewriteResponse response11 = d_rewriter->postRewrite(n11);
- TS_ASSERT(response11.d_node == A
- && response11.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testDifferenceSubtract()
- {
- int n = 3;
- vector<Node> elements = getNStrings(2);
- Node emptyBag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
- Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
- Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
- Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
- Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
- Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
- Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
- Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
-
- // (difference_subtract A emptybag) = A
- Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
- RewriteResponse response1 = d_rewriter->postRewrite(n1);
- TS_ASSERT(response1.d_node == A
- && response1.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract emptybag A) = emptyBag
- Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
- RewriteResponse response2 = d_rewriter->postRewrite(n2);
- TS_ASSERT(response2.d_node == emptyBag
- && response2.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract A A) = emptybag
- Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A);
- RewriteResponse response3 = d_rewriter->postRewrite(n3);
- TS_ASSERT(response3.d_node == emptyBag
- && response3.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract (union_disjoint A B) A) = B
- Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
- RewriteResponse response4 = d_rewriter->postRewrite(n4);
- TS_ASSERT(response4.d_node == B
- && response4.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract (union_disjoint B A) A) = B
- Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
- RewriteResponse response5 = d_rewriter->postRewrite(n5);
- TS_ASSERT(response5.d_node == B
- && response4.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract A (union_disjoint A B)) = emptybag
- Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
- RewriteResponse response6 = d_rewriter->postRewrite(n6);
- TS_ASSERT(response6.d_node == emptyBag
- && response6.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract A (union_disjoint B A)) = emptybag
- Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
- RewriteResponse response7 = d_rewriter->postRewrite(n7);
- TS_ASSERT(response7.d_node == emptyBag
- && response7.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract A (union_max A B)) = emptybag
- Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
- RewriteResponse response8 = d_rewriter->postRewrite(n8);
- TS_ASSERT(response8.d_node == emptyBag
- && response8.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract A (union_max B A)) = emptybag
- Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
- RewriteResponse response9 = d_rewriter->postRewrite(n9);
- TS_ASSERT(response9.d_node == emptyBag
- && response9.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract (intersection_min A B) A) = emptybag
- Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
- RewriteResponse response10 = d_rewriter->postRewrite(n10);
- TS_ASSERT(response10.d_node == emptyBag
- && response10.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_subtract (intersection_min B A) A) = emptybag
- Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
- RewriteResponse response11 = d_rewriter->postRewrite(n11);
- TS_ASSERT(response11.d_node == emptyBag
- && response11.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testDifferenceRemove()
- {
- int n = 3;
- vector<Node> elements = getNStrings(2);
- Node emptyBag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
- Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
- Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
- Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
- Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
- Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
- Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
- Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
-
- // (difference_remove A emptybag) = A
- Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
- RewriteResponse response1 = d_rewriter->postRewrite(n1);
- TS_ASSERT(response1.d_node == A
- && response1.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove emptybag A) = emptyBag
- Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
- RewriteResponse response2 = d_rewriter->postRewrite(n2);
- TS_ASSERT(response2.d_node == emptyBag
- && response2.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove A A) = emptybag
- Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A);
- RewriteResponse response3 = d_rewriter->postRewrite(n3);
- TS_ASSERT(response3.d_node == emptyBag
- && response3.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove A (union_disjoint A B)) = emptybag
- Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
- RewriteResponse response6 = d_rewriter->postRewrite(n6);
- TS_ASSERT(response6.d_node == emptyBag
- && response6.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove A (union_disjoint B A)) = emptybag
- Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
- RewriteResponse response7 = d_rewriter->postRewrite(n7);
- TS_ASSERT(response7.d_node == emptyBag
- && response7.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove A (union_max A B)) = emptybag
- Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
- RewriteResponse response8 = d_rewriter->postRewrite(n8);
- TS_ASSERT(response8.d_node == emptyBag
- && response8.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove A (union_max B A)) = emptybag
- Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
- RewriteResponse response9 = d_rewriter->postRewrite(n9);
- TS_ASSERT(response9.d_node == emptyBag
- && response9.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove (intersection_min A B) A) = emptybag
- Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
- RewriteResponse response10 = d_rewriter->postRewrite(n10);
- TS_ASSERT(response10.d_node == emptyBag
- && response10.d_status == REWRITE_AGAIN_FULL);
-
- // (difference_remove (intersection_min B A) A) = emptybag
- Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
- RewriteResponse response11 = d_rewriter->postRewrite(n11);
- TS_ASSERT(response11.d_node == emptyBag
- && response11.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testChoose()
- {
- Node x = d_nm->mkSkolem("x", d_nm->stringType());
- Node c = d_nm->mkConst(Rational(3));
- Node bag = d_nm->mkNode(MK_BAG, x, c);
-
- // (bag.choose (mkBag x c)) = x where c is a constant > 0
- Node n1 = d_nm->mkNode(BAG_CHOOSE, bag);
- RewriteResponse response1 = d_rewriter->postRewrite(n1);
- TS_ASSERT(response1.d_node == x
- && response1.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testBagCard()
- {
- Node x = d_nm->mkSkolem("x", d_nm->stringType());
- Node emptyBag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- Node zero = d_nm->mkConst(Rational(0));
- Node c = d_nm->mkConst(Rational(3));
- Node bag = d_nm->mkNode(MK_BAG, x, c);
- vector<Node> elements = getNStrings(2);
- Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(4)));
- Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(5)));
- Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
-
- // TODO(projects#223): enable this test after implementing bags normal form
- // // (bag.card emptybag) = 0
- // Node n1 = d_nm->mkNode(BAG_CARD, emptyBag);
- // RewriteResponse response1 = d_rewriter->postRewrite(n1);
- // TS_ASSERT(response1.d_node == zero && response1.d_status ==
- // REWRITE_AGAIN_FULL);
-
- // (bag.card (mkBag x c)) = c where c is a constant > 0
- Node n2 = d_nm->mkNode(BAG_CARD, bag);
- RewriteResponse response2 = d_rewriter->postRewrite(n2);
- TS_ASSERT(response2.d_node == c
- && response2.d_status == REWRITE_AGAIN_FULL);
-
- // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
- Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB);
- Node cardA = d_nm->mkNode(BAG_CARD, A);
- Node cardB = d_nm->mkNode(BAG_CARD, B);
- Node plus = d_nm->mkNode(PLUS, cardA, cardB);
- RewriteResponse response3 = d_rewriter->postRewrite(n3);
- TS_ASSERT(response3.d_node == plus
- && response3.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testIsSingleton()
- {
- Node emptybag =
- d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
- Node x = d_nm->mkSkolem("x", d_nm->stringType());
- Node c = d_nm->mkSkolem("c", d_nm->integerType());
- Node bag = d_nm->mkNode(MK_BAG, x, c);
-
- // TODO(projects#223): complete this function
- // (bag.is_singleton emptybag) = false
- // Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag);
- // RewriteResponse response1 = d_rewriter->postRewrite(n1);
- // TS_ASSERT(response1.d_node == d_nm->mkConst(false)
- // && response1.d_status == REWRITE_AGAIN_FULL);
-
- // (bag.is_singleton (mkBag x c) = (c == 1)
- Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag);
- RewriteResponse response2 = d_rewriter->postRewrite(n2);
- Node one = d_nm->mkConst(Rational(1));
- Node equal = c.eqNode(one);
- TS_ASSERT(response2.d_node == equal
- && response2.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testFromSet()
- {
- Node x = d_nm->mkSkolem("x", d_nm->stringType());
- Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
-
- // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
- Node n = d_nm->mkNode(BAG_FROM_SET, singleton);
- RewriteResponse response = d_rewriter->postRewrite(n);
- Node one = d_nm->mkConst(Rational(1));
- Node bag = d_nm->mkNode(MK_BAG, x, one);
- TS_ASSERT(response.d_node == bag
- && response.d_status == REWRITE_AGAIN_FULL);
- }
-
- void testToSet()
- {
- Node x = d_nm->mkSkolem("x", d_nm->stringType());
- Node bag = d_nm->mkNode(MK_BAG, x, d_nm->mkConst(Rational(5)));
-
- // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
- Node n = d_nm->mkNode(BAG_TO_SET, bag);
- RewriteResponse response = d_rewriter->postRewrite(n);
- Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
- TS_ASSERT(response.d_node == singleton
- && response.d_status == REWRITE_AGAIN_FULL);
- }
-
- private:
- std::unique_ptr<ExprManager> d_em;
- std::unique_ptr<SmtEngine> d_smt;
- std::unique_ptr<NodeManager> d_nm;
-
- std::unique_ptr<BagsRewriter> d_rewriter;
-}; /* class BagsTypeRuleBlack */
--- /dev/null
+/********************* */
+/*! \file theory_bags_rewriter_white.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief White box testing of bags rewriter
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/bags_rewriter.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsTypeRuleWhite : public CxxTest::TestSuite
+{
+ public:
+ void setUp() override
+ {
+ d_em.reset(new ExprManager());
+ d_smt.reset(new SmtEngine(d_em.get()));
+ d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+ d_smt->finishInit();
+ d_rewriter.reset(new BagsRewriter(nullptr));
+ }
+
+ void tearDown() override
+ {
+ d_rewriter.reset();
+ d_smt.reset();
+ d_nm.release();
+ d_em.reset();
+ }
+
+ std::vector<Node> getNStrings(size_t n)
+ {
+ std::vector<Node> elements(n);
+ for (size_t i = 0; i < n; i++)
+ {
+ elements[i] = d_nm->mkSkolem("x", d_nm->stringType());
+ }
+ return elements;
+ }
+
+ void testEmptyBagNormalForm()
+ {
+ Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
+ // empty bags are in normal form
+ TS_ASSERT(emptybag.isConst());
+ RewriteResponse response = d_rewriter->postRewrite(emptybag);
+ TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE);
+ }
+
+ void testBagEquality()
+ {
+ vector<Node> elements = getNStrings(2);
+ Node x = elements[0];
+ Node y = elements[1];
+ Node c = d_nm->mkSkolem("c", d_nm->integerType());
+ Node d = d_nm->mkSkolem("d", d_nm->integerType());
+ Node bagX = d_nm->mkBag(d_nm->stringType(), x, c);
+ Node bagY = d_nm->mkBag(d_nm->stringType(), y, d);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+
+ // (= A A) = true where A is a bag
+ Node n1 = emptyBag.eqNode(emptyBag);
+ RewriteResponse response1 = d_rewriter->preRewrite(n1);
+ TS_ASSERT(response1.d_node == d_nm->mkConst(true)
+ && response1.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testMkBagConstantElement()
+ {
+ vector<Node> elements = getNStrings(1);
+ Node negative = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1)));
+ Node zero = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0)));
+ Node positive = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1)));
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+ RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+ RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+ // bags with non-positive multiplicity are rewritten as empty bags
+ TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+ && negativeResponse.d_node == emptybag);
+ TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+ && zeroResponse.d_node == emptybag);
+
+ // no change for positive
+ TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+ && positive == positiveResponse.d_node);
+ }
+
+ void testMkBagVariableElement()
+ {
+ Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+ Node variable =
+ d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(-1)));
+ Node negative =
+ d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(-1)));
+ Node zero =
+ d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(0)));
+ Node positive =
+ d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(1)));
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+ RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+ RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+ // bags with non-positive multiplicity are rewritten as empty bags
+ TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+ && negativeResponse.d_node == emptybag);
+ TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+ && zeroResponse.d_node == emptybag);
+
+ // no change for positive
+ TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+ && positive == positiveResponse.d_node);
+ }
+
+ void testBagCount()
+ {
+ int n = 3;
+ Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+ Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType())));
+ Node bag =
+ d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(n)));
+
+ // (bag.count x emptybag) = 0
+ Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL
+ && response1.d_node == d_nm->mkConst(Rational(0)));
+
+ // (bag.count x (mkBag x c) = c where c > 0 is a constant
+ Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL
+ && response2.d_node == d_nm->mkConst(Rational(n)));
+ }
+
+ void testUnionMax()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkBag(
+ d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+ // (union_max A emptybag) = A
+ Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max emptybag A) = A
+ Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
+ TS_ASSERT(response2.d_node == A
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A A) = A
+ Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
+ TS_ASSERT(response3.d_node == A
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_max A B)) = (union_max A B)
+ Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB);
+ RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
+ TS_ASSERT(response4.d_node == unionMaxAB
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_max B A)) = (union_max B A)
+ Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA);
+ RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
+ TS_ASSERT(response5.d_node == unionMaxBA
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_max A B) A) = (union_max A B)
+ Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A);
+ RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
+ TS_ASSERT(response6.d_node == unionMaxAB
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_max B A) A) = (union_max B A)
+ Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A);
+ RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
+ TS_ASSERT(response7.d_node == unionMaxBA
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_disjoint A B)) = (union_disjoint A B)
+ Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
+ TS_ASSERT(response8.d_node == unionDisjointAB
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_disjoint B A)) = (union_disjoint B A)
+ Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
+ TS_ASSERT(response9.d_node == unionDisjointBA
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_disjoint A B) A) = (union_disjoint A B)
+ Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
+ TS_ASSERT(response10.d_node == unionDisjointAB
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_disjoint B A) A) = (union_disjoint B A)
+ Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
+ TS_ASSERT(response11.d_node == unionDisjointBA
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testUnionDisjoint()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkBag(
+ d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+ Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+ // (union_disjoint A emptybag) = A
+ Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint emptybag A) = A
+ Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
+ TS_ASSERT(response2.d_node == A
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint (union_max A B) (intersection_min B A)) =
+ // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+ Node unionDisjoint3 =
+ d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
+ RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
+ TS_ASSERT(response3.d_node == unionDisjointAB
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint (intersection_min B A)) (union_max A B) =
+ // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
+ Node unionDisjoint4 =
+ d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
+ RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
+ TS_ASSERT(response4.d_node == unionDisjointBA
+ && response4.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testIntersectionMin()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkBag(
+ d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+ // (intersection_min A emptybag) = emptyBag
+ Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == emptyBag
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min emptybag A) = emptyBag
+ Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == emptyBag
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A A) = A
+ Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == A
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_max A B) = A
+ Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB);
+ RewriteResponse response4 = d_rewriter->postRewrite(n4);
+ TS_ASSERT(response4.d_node == A
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_max B A) = A
+ Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA);
+ RewriteResponse response5 = d_rewriter->postRewrite(n5);
+ TS_ASSERT(response5.d_node == A
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_max A B) A) = A
+ Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A);
+ RewriteResponse response6 = d_rewriter->postRewrite(n6);
+ TS_ASSERT(response6.d_node == A
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_max B A) A) = A
+ Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A);
+ RewriteResponse response7 = d_rewriter->postRewrite(n7);
+ TS_ASSERT(response7.d_node == A
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_disjoint A B) = A
+ Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(n8);
+ TS_ASSERT(response8.d_node == A
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_disjoint B A) = A
+ Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(n9);
+ TS_ASSERT(response9.d_node == A
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_disjoint A B) A) = A
+ Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(n10);
+ TS_ASSERT(response10.d_node == A
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_disjoint B A) A) = A
+ Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(n11);
+ TS_ASSERT(response11.d_node == A
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testDifferenceSubtract()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkBag(
+ d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+ Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+ Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+ // (difference_subtract A emptybag) = A
+ Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract emptybag A) = emptyBag
+ Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == emptyBag
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A A) = emptybag
+ Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == emptyBag
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (union_disjoint A B) A) = B
+ Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
+ RewriteResponse response4 = d_rewriter->postRewrite(n4);
+ TS_ASSERT(response4.d_node == B
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (union_disjoint B A) A) = B
+ Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
+ RewriteResponse response5 = d_rewriter->postRewrite(n5);
+ TS_ASSERT(response5.d_node == B
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_disjoint A B)) = emptybag
+ Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
+ RewriteResponse response6 = d_rewriter->postRewrite(n6);
+ TS_ASSERT(response6.d_node == emptyBag
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_disjoint B A)) = emptybag
+ Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
+ RewriteResponse response7 = d_rewriter->postRewrite(n7);
+ TS_ASSERT(response7.d_node == emptyBag
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_max A B)) = emptybag
+ Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(n8);
+ TS_ASSERT(response8.d_node == emptyBag
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_max B A)) = emptybag
+ Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(n9);
+ TS_ASSERT(response9.d_node == emptyBag
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (intersection_min A B) A) = emptybag
+ Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(n10);
+ TS_ASSERT(response10.d_node == emptyBag
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (intersection_min B A) A) = emptybag
+ Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(n11);
+ TS_ASSERT(response11.d_node == emptyBag
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testDifferenceRemove()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkBag(
+ d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+ Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+ Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+ // (difference_remove A emptybag) = A
+ Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove emptybag A) = emptyBag
+ Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == emptyBag
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A A) = emptybag
+ Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == emptyBag
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_disjoint A B)) = emptybag
+ Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
+ RewriteResponse response6 = d_rewriter->postRewrite(n6);
+ TS_ASSERT(response6.d_node == emptyBag
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_disjoint B A)) = emptybag
+ Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
+ RewriteResponse response7 = d_rewriter->postRewrite(n7);
+ TS_ASSERT(response7.d_node == emptyBag
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_max A B)) = emptybag
+ Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(n8);
+ TS_ASSERT(response8.d_node == emptyBag
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_max B A)) = emptybag
+ Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(n9);
+ TS_ASSERT(response9.d_node == emptyBag
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove (intersection_min A B) A) = emptybag
+ Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(n10);
+ TS_ASSERT(response10.d_node == emptyBag
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove (intersection_min B A) A) = emptybag
+ Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(n11);
+ TS_ASSERT(response11.d_node == emptyBag
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testChoose()
+ {
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node c = d_nm->mkConst(Rational(3));
+ Node bag = d_nm->mkBag(d_nm->stringType(), x, c);
+
+ // (bag.choose (mkBag x c)) = x where c is a constant > 0
+ Node n1 = d_nm->mkNode(BAG_CHOOSE, bag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == x
+ && response1.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testBagCard()
+ {
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node zero = d_nm->mkConst(Rational(0));
+ Node c = d_nm->mkConst(Rational(3));
+ Node bag = d_nm->mkBag(d_nm->stringType(), x, c);
+ vector<Node> elements = getNStrings(2);
+ Node A = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(4)));
+ Node B = d_nm->mkBag(
+ d_nm->stringType(), elements[1], d_nm->mkConst(Rational(5)));
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+
+ // TODO(projects#223): enable this test after implementing bags normal form
+ // // (bag.card emptybag) = 0
+ // Node n1 = d_nm->mkNode(BAG_CARD, emptyBag);
+ // RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ // TS_ASSERT(response1.d_node == zero && response1.d_status ==
+ // REWRITE_AGAIN_FULL);
+
+ // (bag.card (mkBag x c)) = c where c is a constant > 0
+ Node n2 = d_nm->mkNode(BAG_CARD, bag);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == c
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+ Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB);
+ Node cardA = d_nm->mkNode(BAG_CARD, A);
+ Node cardB = d_nm->mkNode(BAG_CARD, B);
+ Node plus = d_nm->mkNode(PLUS, cardA, cardB);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == plus
+ && response3.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testIsSingleton()
+ {
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node c = d_nm->mkSkolem("c", d_nm->integerType());
+ Node bag = d_nm->mkBag(d_nm->stringType(), x, c);
+
+ // TODO(projects#223): complete this function
+ // (bag.is_singleton emptybag) = false
+ // Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag);
+ // RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ // TS_ASSERT(response1.d_node == d_nm->mkConst(false)
+ // && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (bag.is_singleton (mkBag x c) = (c == 1)
+ Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ Node one = d_nm->mkConst(Rational(1));
+ Node equal = c.eqNode(one);
+ TS_ASSERT(response2.d_node == equal
+ && response2.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testFromSet()
+ {
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
+
+ // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
+ Node n = d_nm->mkNode(BAG_FROM_SET, singleton);
+ RewriteResponse response = d_rewriter->postRewrite(n);
+ Node one = d_nm->mkConst(Rational(1));
+ Node bag = d_nm->mkBag(d_nm->stringType(), x, one);
+ TS_ASSERT(response.d_node == bag
+ && response.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testToSet()
+ {
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node bag = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(5)));
+
+ // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
+ Node n = d_nm->mkNode(BAG_TO_SET, bag);
+ RewriteResponse response = d_rewriter->postRewrite(n);
+ Node singleton = d_nm->mkSingleton(d_nm->stringType(), x);
+ TS_ASSERT(response.d_node == singleton
+ && response.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ private:
+ std::unique_ptr<ExprManager> d_em;
+ std::unique_ptr<SmtEngine> d_smt;
+ std::unique_ptr<NodeManager> d_nm;
+
+ std::unique_ptr<BagsRewriter> d_rewriter;
+}; /* class BagsTypeRuleBlack */
+++ /dev/null
-/********************* */
-/*! \file theory_bags_type_rules_black.h
- ** \verbatim
- ** Top contributors (to current version):
- ** Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved. See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Black box testing of bags typing rules
- **/
-
-#include <cxxtest/TestSuite.h>
-
-#include "expr/dtype.h"
-#include "smt/smt_engine.h"
-#include "theory/bags/theory_bags_type_rules.h"
-#include "theory/strings/type_enumerator.h"
-
-using namespace CVC4;
-using namespace CVC4::smt;
-using namespace CVC4::theory;
-using namespace CVC4::kind;
-using namespace CVC4::theory::bags;
-using namespace std;
-
-typedef expr::Attribute<Node, Node> attribute;
-
-class BagsTypeRuleBlack : public CxxTest::TestSuite
-{
- public:
- void setUp() override
- {
- d_em.reset(new ExprManager());
- d_smt.reset(new SmtEngine(d_em.get()));
- d_nm.reset(NodeManager::fromExprManager(d_em.get()));
- d_smt->finishInit();
- }
-
- void tearDown() override
- {
- d_smt.reset();
- d_nm.release();
- d_em.reset();
- }
-
- std::vector<Node> getNStrings(size_t n)
- {
- std::vector<Node> elements(n);
- CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType());
-
- for (size_t i = 0; i < n; i++)
- {
- ++enumerator;
- elements[i] = *enumerator;
- }
-
- return elements;
- }
-
- void testCountOperator()
- {
- vector<Node> elements = getNStrings(1);
- Node bag = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(100)));
-
- Node count = d_nm->mkNode(BAG_COUNT, elements[0], bag);
- Node node = d_nm->mkConst(Rational(10));
-
- // node of type Int is not compatible with bag of type (Bag String)
- TS_ASSERT_THROWS(d_nm->mkNode(BAG_COUNT, node, bag).getType(true),
- TypeCheckingExceptionPrivate&);
- }
-
- void testMkBagOperator()
- {
- vector<Node> elements = getNStrings(1);
- Node negative =
- d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1)));
- Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0)));
- Node positive =
- d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1)));
-
- // only positive multiplicity are constants
- TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), negative));
- TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), zero));
- TS_ASSERT(MkBagTypeRule::computeIsConst(d_nm.get(), positive));
- }
-
- void testFromSetOperator()
- {
- vector<Node> elements = getNStrings(1);
- Node set = d_nm->mkSingleton(d_nm->stringType(), elements[0]);
- TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_FROM_SET, set));
- TS_ASSERT(d_nm->mkNode(BAG_FROM_SET, set).getType().isBag());
- }
-
- void testToSetOperator()
- {
- vector<Node> elements = getNStrings(1);
- Node bag = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(10)));
- TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag));
- TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet());
- }
-
- private:
- std::unique_ptr<ExprManager> d_em;
- std::unique_ptr<SmtEngine> d_smt;
- std::unique_ptr<NodeManager> d_nm;
-}; /* class BagsTypeRuleBlack */
--- /dev/null
+/********************* */
+/*! \file theory_bags_type_rules_black.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Black box testing of bags typing rules
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/theory_bags_type_rules.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsTypeRuleWhite : public CxxTest::TestSuite
+{
+ public:
+ void setUp() override
+ {
+ d_em.reset(new ExprManager());
+ d_smt.reset(new SmtEngine(d_em.get()));
+ d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+ d_smt->finishInit();
+ }
+
+ void tearDown() override
+ {
+ d_smt.reset();
+ d_nm.release();
+ d_em.reset();
+ }
+
+ std::vector<Node> getNStrings(size_t n)
+ {
+ std::vector<Node> elements(n);
+ CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType());
+
+ for (size_t i = 0; i < n; i++)
+ {
+ ++enumerator;
+ elements[i] = *enumerator;
+ }
+
+ return elements;
+ }
+
+ void testCountOperator()
+ {
+ vector<Node> elements = getNStrings(1);
+ Node bag = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(100)));
+
+ Node count = d_nm->mkNode(BAG_COUNT, elements[0], bag);
+ Node node = d_nm->mkConst(Rational(10));
+
+ // node of type Int is not compatible with bag of type (Bag String)
+ TS_ASSERT_THROWS(d_nm->mkNode(BAG_COUNT, node, bag).getType(true),
+ TypeCheckingExceptionPrivate&);
+ }
+
+ void testMkBagOperator()
+ {
+ vector<Node> elements = getNStrings(1);
+ Node negative = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1)));
+ Node zero = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0)));
+ Node positive = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1)));
+
+ // only positive multiplicity are constants
+ TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), negative));
+ TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), zero));
+ TS_ASSERT(MkBagTypeRule::computeIsConst(d_nm.get(), positive));
+ }
+
+ void testFromSetOperator()
+ {
+ vector<Node> elements = getNStrings(1);
+ Node set = d_nm->mkSingleton(d_nm->stringType(), elements[0]);
+ TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_FROM_SET, set));
+ TS_ASSERT(d_nm->mkNode(BAG_FROM_SET, set).getType().isBag());
+ }
+
+ void testToSetOperator()
+ {
+ vector<Node> elements = getNStrings(1);
+ Node bag = d_nm->mkBag(d_nm->stringType(), elements[0], d_nm->mkConst(Rational(10)));
+ TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag));
+ TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet());
+ }
+
+ private:
+ std::unique_ptr<ExprManager> d_em;
+ std::unique_ptr<SmtEngine> d_smt;
+ std::unique_ptr<NodeManager> d_nm;
+}; /* class BagsTypeRuleBlack */