From: mudathirmahgoub Date: Wed, 21 Oct 2020 13:19:55 +0000 (-0500) Subject: Add operator MakeBagOp for constructing bags (#5209) X-Git-Tag: cvc5-1.0.0~2677 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=098cee0ea412e24e24caa79307e2950a640279af;p=cvc5.git Add operator MakeBagOp for constructing bags (#5209) This PR removes subtyping rules for bags and add operator MakeBagOp similar to SingletonOp --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4d96fa0b3..5966debc1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -429,6 +429,8 @@ libcvc4_add_sources( theory/bags/bags_statistics.h theory/bags/inference_manager.cpp theory/bags/inference_manager.h + theory/bags/make_bag_op.cpp + theory/bags/make_bag_op.h theory/bags/normal_form.cpp theory/bags/normal_form.h theory/bags/rewrites.cpp diff --git a/src/api/cvc4cppkind.h b/src/api/cvc4cppkind.h index 913a4a993..d6ee24f1e 100644 --- a/src/api/cvc4cppkind.h +++ b/src/api/cvc4cppkind.h @@ -1841,7 +1841,8 @@ enum CVC4_PUBLIC Kind : int32_t */ MEMBER, /** - * The set of the single element given as a parameter. + * Construct a singleton set from an element given as a parameter. + * The returned set has same type of the element. * Parameters: 1 * -[1]: Single element * Create with: diff --git a/src/expr/node_manager.cpp b/src/expr/node_manager.cpp index f8057006c..e9f121047 100644 --- a/src/expr/node_manager.cpp +++ b/src/expr/node_manager.cpp @@ -961,6 +961,17 @@ Node NodeManager::mkSingleton(const TypeNode& t, const TNode n) return singleton; } +Node NodeManager::mkBag(const TypeNode& t, const TNode n, const TNode m) +{ + Assert(n.getType().isSubtypeOf(t)) + << "Invalid operands for mkBag. The type '" << n.getType() + << "' of node '" << n << "' is not a subtype of '" << t << "'." + << std::endl; + Node op = mkConst(MakeBagOp(t)); + Node bag = mkNode(kind::MK_BAG, op, n, m); + return bag; +} + Node NodeManager::mkAbstractValue(const TypeNode& type) { Node n = mkConst(AbstractValue(++d_abstractValueCount)); n.setAttribute(TypeAttr(), type); diff --git a/src/expr/node_manager.h b/src/expr/node_manager.h index 5427c3b6a..8f2237523 100644 --- a/src/expr/node_manager.h +++ b/src/expr/node_manager.h @@ -578,12 +578,22 @@ class NodeManager { /** * Create a singleton set from the given element n. * @param t the element type of the returned set. - * Note that the type of n needs to be a subtype of t. + * Note that the type of n needs to be a subtype of t. * @param n the single element in the singleton. * @return a singleton set constructed from the element n. */ Node mkSingleton(const TypeNode& t, const TNode n); + /** + * Create a bag from the given element n along with its multiplicity m. + * @param t the element type of the returned bag. + * Note that the type of n needs to be a subtype of t. + * @param n the element that is used to to construct the bag + * @param m the multiplicity of the element n + * @return a bag that contains m occurrences of n. + */ + Node mkBag(const TypeNode& t, const TNode n, const TNode m); + /** * Create a constant of type T. It will have the appropriate * CONST_* kind defined for T. diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 659b1eef2..e917a9d0d 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -574,11 +574,12 @@ TypeNode TypeNode::commonTypeNode(TypeNode t0, TypeNode t1, bool isLeast) { case kind::ARRAY_TYPE: case kind::DATATYPE_TYPE: case kind::PARAMETRIC_DATATYPE: - case kind::SEQUENCE_TYPE: return TypeNode(); + case kind::SEQUENCE_TYPE: case kind::SET_TYPE: + case kind::BAG_TYPE: { - // we don't support subtyping for sets - return TypeNode(); // return null type + // we don't support subtyping except for built in types Int and Real. + return TypeNode(); // return null type } case kind::SEXPR_TYPE: Unimplemented() diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 1faaf55c0..c413a5e7e 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -438,7 +438,8 @@ BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const { // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1) Node one = d_nm->mkConst(Rational(1)); - Node bag = d_nm->mkNode(MK_BAG, n[0][0], one); + TypeNode type = n[0].getType().getSetElementType(); + Node bag = d_nm->mkBag(type, n[0][0], one); return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON); } return BagsRewriteResponse(n, Rewrite::NONE); diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 72326de08..86e89e0bd 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -47,7 +47,14 @@ operator DIFFERENCE_REMOVE 2 "bag difference remove (removes shared elements)" operator BAG_IS_INCLUDED 2 "inclusion predicate for bags (less than or equal multiplicities)" operator BAG_COUNT 2 "multiplicity of an element in a bag" -operator MK_BAG 2 "constructs a bag from one element along with its multiplicity" + +constant MK_BAG_OP \ + ::CVC4::MakeBagOp \ + ::CVC4::MakeBagOpHashFunction \ + "theory/bags/make_bag_op.h" \ + "operator for MK_BAG; payload is an instance of the CVC4::MakeBagOp class" +parameterized MK_BAG MK_BAG_OP 2 \ +"constructs a bag from one element along with its multiplicity" # The operator bag-is-singleton returns whether the given bag is a singleton operator BAG_IS_SINGLETON 1 "return whether the given bag is a singleton" @@ -69,6 +76,7 @@ typerule DIFFERENCE_SUBTRACT ::CVC4::theory::bags::BinaryOperatorTypeRule typerule DIFFERENCE_REMOVE ::CVC4::theory::bags::BinaryOperatorTypeRule typerule BAG_IS_INCLUDED ::CVC4::theory::bags::IsIncludedTypeRule typerule BAG_COUNT ::CVC4::theory::bags::CountTypeRule +typerule MK_BAG_OP "SimpleTypeRule" typerule MK_BAG ::CVC4::theory::bags::MkBagTypeRule typerule EMPTYBAG ::CVC4::theory::bags::EmptyBagTypeRule typerule BAG_CARD ::CVC4::theory::bags::CardTypeRule diff --git a/src/theory/bags/make_bag_op.cpp b/src/theory/bags/make_bag_op.cpp new file mode 100644 index 000000000..6a535afc2 --- /dev/null +++ b/src/theory/bags/make_bag_op.cpp @@ -0,0 +1,48 @@ +/********************* */ +/*! \file bag_op.cpp + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief a class for MK_BAG operator + **/ + +#include + +#include "expr/type_node.h" +#include "make_bag_op.h" + +namespace CVC4 { + +std::ostream& operator<<(std::ostream& out, const MakeBagOp& op) +{ + return out << "(mkBag_op " << op.getType() << ')'; +} + +size_t MakeBagOpHashFunction::operator()(const MakeBagOp& op) const +{ + return TypeNodeHashFunction()(op.getType()); +} + +MakeBagOp::MakeBagOp(const TypeNode& elementType) + : d_type(new TypeNode(elementType)) +{ +} + +MakeBagOp::MakeBagOp(const MakeBagOp& op) : d_type(new TypeNode(op.getType())) +{ +} + +const TypeNode& MakeBagOp::getType() const { return *d_type; } + +bool MakeBagOp::operator==(const MakeBagOp& op) const +{ + return getType() == op.getType(); +} + +} // namespace CVC4 diff --git a/src/theory/bags/make_bag_op.h b/src/theory/bags/make_bag_op.h new file mode 100644 index 000000000..b47930879 --- /dev/null +++ b/src/theory/bags/make_bag_op.h @@ -0,0 +1,63 @@ +/********************* */ +/*! \file mk_bag_op.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief a class for MK_BAG operator + **/ + +#include "cvc4_public.h" + +#ifndef CVC4__MAKE_BAG_OP_H +#define CVC4__MAKE_BAG_OP_H + +#include + +namespace CVC4 { + +class TypeNode; + +/** + * The class is an operator for kind MK_BAG used to construct bags. + * It specifies the type of the element especially when it is a constant. + * e.g. the type of rational 1 is Int, however + * (mkBag (mkBag_op Real) 1) is of type (Bag Real), not (Bag Int). + * Note that the type passed to the constructor is the element's type, not the + * bag type. + */ +class MakeBagOp +{ + public: + MakeBagOp(const TypeNode& elementType); + MakeBagOp(const MakeBagOp& op); + + /** return the type of the current object */ + const TypeNode& getType() const; + + bool operator==(const MakeBagOp& op) const; + + private: + MakeBagOp(); + /** a pointer to the type of the bag element */ + std::unique_ptr d_type; +}; /* class MakeBagOp */ + +std::ostream& operator<<(std::ostream& out, const MakeBagOp& op); + +/** + * Hash function for the MakeBagOpHashFunction objects. + */ +struct CVC4_PUBLIC MakeBagOpHashFunction +{ + size_t operator()(const MakeBagOp& op) const; +}; /* struct MakeBagOpHashFunction */ + +} // namespace CVC4 + +#endif /* CVC4__MAKE_BAG_OP_H */ diff --git a/src/theory/bags/theory_bags_type_enumerator.cpp b/src/theory/bags/theory_bags_type_enumerator.cpp index 7975bb379..727407937 100644 --- a/src/theory/bags/theory_bags_type_enumerator.cpp +++ b/src/theory/bags/theory_bags_type_enumerator.cpp @@ -54,7 +54,8 @@ BagEnumerator& BagEnumerator::operator++() { // increase the multiplicity by one Node one = d_nodeManager->mkConst(Rational(1)); - Node singleton = d_nodeManager->mkNode(kind::MK_BAG, d_element, one); + TypeNode elementType = d_elementTypeEnumerator.getType(); + Node singleton = d_nodeManager->mkBag(elementType, d_element, one); if (d_currentBag.getKind() == kind::EMPTYBAG) { d_currentBag = singleton; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 67293e222..75f57ec88 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -42,19 +42,11 @@ struct BinaryOperatorTypeRule TypeNode secondBagType = n[1].getType(check); if (secondBagType != bagType) { - if (n.getKind() == kind::INTERSECTION_MIN) - { - bagType = TypeNode::mostCommonTypeNode(secondBagType, bagType); - } - else - { - bagType = TypeNode::leastCommonTypeNode(secondBagType, bagType); - } - if (bagType.isNull()) - { - throw TypeCheckingExceptionPrivate( - n, "operator expects two bags of comparable types"); - } + std::stringstream ss; + ss << "Operator " << n.getKind() + << " expects two bags of the same type. Found types '" << bagType + << "' and '" << secondBagType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); } } return bagType; @@ -110,15 +102,9 @@ struct CountTypeRule n, "checking for membership in a non-bag"); } TypeNode elementType = n[0].getType(check); - // TODO(projects#226): comments from sets - // - // T : (Bag Int) - // B : (Bag Real) - // (= (as T (Bag Real)) B) - // (= (bag-count 0.5 B) 1) - // ...where (bag-count 0.5 T) is inferred - - if (!elementType.isComparableTo(bagType.getBagElementType())) + // e.g. (count 1 (mkBag (mkBag_op Real) 1.0 3))) is 3 whereas + // (count 1.0 (mkBag (mkBag_op Int) 1 3))) throws a typing error + if (!elementType.isSubtypeOf(bagType.getBagElementType())) { std::stringstream ss; ss << "member operating on bags of different types:\n" @@ -136,7 +122,10 @@ struct MkBagTypeRule { static TypeNode computeType(NodeManager* nm, TNode n, bool check) { - Assert(n.getKind() == kind::MK_BAG); + Assert(n.getKind() == kind::MK_BAG && n.hasOperator() + && n.getOperator().getKind() == kind::MK_BAG_OP); + MakeBagOp op = n.getOperator().getConst(); + TypeNode expectedElementType = op.getType(); if (check) { if (n.getNumChildren() != 2) @@ -153,9 +142,21 @@ struct MkBagTypeRule ss << "MK_BAG expects an integer for " << n[1] << ". Found" << type1; throw TypeCheckingExceptionPrivate(n, ss.str()); } + + TypeNode actualElementType = n[0].getType(check); + // the type of the element should be a subtype of the type of the operator + // e.g. (mkBag (mkBag_op Real) 1 1) where 1 is an Int + if (!actualElementType.isSubtypeOf(expectedElementType)) + { + std::stringstream ss; + ss << "The type '" << actualElementType + << "' of the element is not a subtype of '" << expectedElementType + << "' in term : " << n; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } } - return nm->mkBagType(n[0].getType(check)); + return nm->mkBagType(expectedElementType); } static bool computeIsConst(NodeManager* nodeManager, TNode n) diff --git a/test/unit/theory/CMakeLists.txt b/test/unit/theory/CMakeLists.txt index e541a24fb..481c80f26 100644 --- a/test/unit/theory/CMakeLists.txt +++ b/test/unit/theory/CMakeLists.txt @@ -14,8 +14,8 @@ cvc4_add_unit_test_white(evaluator_white theory) cvc4_add_unit_test_white(logic_info_white theory) cvc4_add_unit_test_white(sequences_rewriter_white theory) cvc4_add_unit_test_white(theory_arith_white theory) -cvc4_add_unit_test_white(theory_bags_rewriter_black theory) -cvc4_add_unit_test_white(theory_bags_type_rules_black theory) +cvc4_add_unit_test_white(theory_bags_rewriter_white theory) +cvc4_add_unit_test_white(theory_bags_type_rules_white theory) cvc4_add_unit_test_white(theory_bv_rewriter_white theory) cvc4_add_unit_test_white(theory_bv_white theory) cvc4_add_unit_test_white(theory_engine_white theory) diff --git a/test/unit/theory/theory_bags_rewriter_black.h b/test/unit/theory/theory_bags_rewriter_black.h deleted file mode 100644 index 98f56fd44..000000000 --- a/test/unit/theory/theory_bags_rewriter_black.h +++ /dev/null @@ -1,620 +0,0 @@ -/********************* */ -/*! \file theory_bags_rewriter_black.h - ** \verbatim - ** Top contributors (to current version): - ** Mudathir Mohamed - ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS - ** in the top-level source directory) and their institutional affiliations. - ** All rights reserved. See the file COPYING in the top-level source - ** directory for licensing information.\endverbatim - ** - ** \brief Black box testing of bags rewriter - **/ - -#include - -#include "expr/dtype.h" -#include "smt/smt_engine.h" -#include "theory/bags/bags_rewriter.h" -#include "theory/strings/type_enumerator.h" - -using namespace CVC4; -using namespace CVC4::smt; -using namespace CVC4::theory; -using namespace CVC4::kind; -using namespace CVC4::theory::bags; -using namespace std; - -typedef expr::Attribute attribute; - -class BagsTypeRuleBlack : public CxxTest::TestSuite -{ - public: - void setUp() override - { - d_em.reset(new ExprManager()); - d_smt.reset(new SmtEngine(d_em.get())); - d_nm.reset(NodeManager::fromExprManager(d_em.get())); - d_smt->finishInit(); - d_rewriter.reset(new BagsRewriter(nullptr)); - } - - void tearDown() override - { - d_rewriter.reset(); - d_smt.reset(); - d_nm.release(); - d_em.reset(); - } - - std::vector getNStrings(size_t n) - { - std::vector elements(n); - for (size_t i = 0; i < n; i++) - { - elements[i] = d_nm->mkSkolem("x", d_nm->stringType()); - } - return elements; - } - - void testEmptyBagNormalForm() - { - Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType())); - // empty bags are in normal form - TS_ASSERT(emptybag.isConst()); - RewriteResponse response = d_rewriter->postRewrite(emptybag); - TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE); - } - - void testBagEquality() - { - vector elements = getNStrings(2); - Node x = elements[0]; - Node y = elements[1]; - Node c = d_nm->mkSkolem("c", d_nm->integerType()); - Node d = d_nm->mkSkolem("d", d_nm->integerType()); - Node bagX = d_nm->mkNode(MK_BAG, x, c); - Node bagY = d_nm->mkNode(MK_BAG, y, d); - Node emptyBag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - - // (= A A) = true where A is a bag - Node n1 = emptyBag.eqNode(emptyBag); - RewriteResponse response1 = d_rewriter->preRewrite(n1); - TS_ASSERT(response1.d_node == d_nm->mkConst(true) - && response1.d_status == REWRITE_AGAIN_FULL); - } - - void testMkBagConstantElement() - { - vector elements = getNStrings(1); - Node negative = - d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1))); - Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0))); - Node positive = - d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1))); - Node emptybag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - RewriteResponse negativeResponse = d_rewriter->postRewrite(negative); - RewriteResponse zeroResponse = d_rewriter->postRewrite(zero); - RewriteResponse positiveResponse = d_rewriter->postRewrite(positive); - - // bags with non-positive multiplicity are rewritten as empty bags - TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL - && negativeResponse.d_node == emptybag); - TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL - && zeroResponse.d_node == emptybag); - - // no change for positive - TS_ASSERT(positiveResponse.d_status == REWRITE_DONE - && positive == positiveResponse.d_node); - } - - void testMkBagVariableElement() - { - Node skolem = d_nm->mkSkolem("x", d_nm->stringType()); - Node variable = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1))); - Node negative = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1))); - Node zero = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(0))); - Node positive = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(1))); - Node emptybag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - RewriteResponse negativeResponse = d_rewriter->postRewrite(negative); - RewriteResponse zeroResponse = d_rewriter->postRewrite(zero); - RewriteResponse positiveResponse = d_rewriter->postRewrite(positive); - - // bags with non-positive multiplicity are rewritten as empty bags - TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL - && negativeResponse.d_node == emptybag); - TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL - && zeroResponse.d_node == emptybag); - - // no change for positive - TS_ASSERT(positiveResponse.d_status == REWRITE_DONE - && positive == positiveResponse.d_node); - } - - void testBagCount() - { - int n = 3; - Node skolem = d_nm->mkSkolem("x", d_nm->stringType()); - Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType()))); - Node bag = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(n))); - - // (bag.count x emptybag) = 0 - Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag); - RewriteResponse response1 = d_rewriter->postRewrite(n1); - TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL - && response1.d_node == d_nm->mkConst(Rational(0))); - - // (bag.count x (mkBag x c) = c where c > 0 is a constant - Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag); - RewriteResponse response2 = d_rewriter->postRewrite(n2); - TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL - && response2.d_node == d_nm->mkConst(Rational(n))); - } - - void testUnionMax() - { - int n = 3; - vector elements = getNStrings(2); - Node emptyBag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); - Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); - Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); - Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); - Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); - Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); - - // (union_max A emptybag) = A - Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag); - RewriteResponse response1 = d_rewriter->postRewrite(unionMax1); - TS_ASSERT(response1.d_node == A - && response1.d_status == REWRITE_AGAIN_FULL); - - // (union_max emptybag A) = A - Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A); - RewriteResponse response2 = d_rewriter->postRewrite(unionMax2); - TS_ASSERT(response2.d_node == A - && response2.d_status == REWRITE_AGAIN_FULL); - - // (union_max A A) = A - Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A); - RewriteResponse response3 = d_rewriter->postRewrite(unionMax3); - TS_ASSERT(response3.d_node == A - && response3.d_status == REWRITE_AGAIN_FULL); - - // (union_max A (union_max A B)) = (union_max A B) - Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB); - RewriteResponse response4 = d_rewriter->postRewrite(unionMax4); - TS_ASSERT(response4.d_node == unionMaxAB - && response4.d_status == REWRITE_AGAIN_FULL); - - // (union_max A (union_max B A)) = (union_max B A) - Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA); - RewriteResponse response5 = d_rewriter->postRewrite(unionMax5); - TS_ASSERT(response5.d_node == unionMaxBA - && response4.d_status == REWRITE_AGAIN_FULL); - - // (union_max (union_max A B) A) = (union_max A B) - Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A); - RewriteResponse response6 = d_rewriter->postRewrite(unionMax6); - TS_ASSERT(response6.d_node == unionMaxAB - && response6.d_status == REWRITE_AGAIN_FULL); - - // (union_max (union_max B A) A) = (union_max B A) - Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A); - RewriteResponse response7 = d_rewriter->postRewrite(unionMax7); - TS_ASSERT(response7.d_node == unionMaxBA - && response7.d_status == REWRITE_AGAIN_FULL); - - // (union_max A (union_disjoint A B)) = (union_disjoint A B) - Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB); - RewriteResponse response8 = d_rewriter->postRewrite(unionMax8); - TS_ASSERT(response8.d_node == unionDisjointAB - && response8.d_status == REWRITE_AGAIN_FULL); - - // (union_max A (union_disjoint B A)) = (union_disjoint B A) - Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA); - RewriteResponse response9 = d_rewriter->postRewrite(unionMax9); - TS_ASSERT(response9.d_node == unionDisjointBA - && response9.d_status == REWRITE_AGAIN_FULL); - - // (union_max (union_disjoint A B) A) = (union_disjoint A B) - Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A); - RewriteResponse response10 = d_rewriter->postRewrite(unionMax10); - TS_ASSERT(response10.d_node == unionDisjointAB - && response10.d_status == REWRITE_AGAIN_FULL); - - // (union_max (union_disjoint B A) A) = (union_disjoint B A) - Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A); - RewriteResponse response11 = d_rewriter->postRewrite(unionMax11); - TS_ASSERT(response11.d_node == unionDisjointBA - && response11.d_status == REWRITE_AGAIN_FULL); - } - - void testUnionDisjoint() - { - int n = 3; - vector elements = getNStrings(2); - Node emptyBag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); - Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); - Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); - Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); - Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); - Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); - Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); - Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); - - // (union_disjoint A emptybag) = A - Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag); - RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1); - TS_ASSERT(response1.d_node == A - && response1.d_status == REWRITE_AGAIN_FULL); - - // (union_disjoint emptybag A) = A - Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A); - RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2); - TS_ASSERT(response2.d_node == A - && response2.d_status == REWRITE_AGAIN_FULL); - - // (union_disjoint (union_max A B) (intersection_min B A)) = - // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b) - Node unionDisjoint3 = - d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA); - RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3); - TS_ASSERT(response3.d_node == unionDisjointAB - && response3.d_status == REWRITE_AGAIN_FULL); - - // (union_disjoint (intersection_min B A)) (union_max A B) = - // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b) - Node unionDisjoint4 = - d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA); - RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4); - TS_ASSERT(response4.d_node == unionDisjointBA - && response4.d_status == REWRITE_AGAIN_FULL); - } - - void testIntersectionMin() - { - int n = 3; - vector elements = getNStrings(2); - Node emptyBag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); - Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); - Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); - Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); - Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); - Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); - - // (intersection_min A emptybag) = emptyBag - Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag); - RewriteResponse response1 = d_rewriter->postRewrite(n1); - TS_ASSERT(response1.d_node == emptyBag - && response1.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min emptybag A) = emptyBag - Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A); - RewriteResponse response2 = d_rewriter->postRewrite(n2); - TS_ASSERT(response2.d_node == emptyBag - && response2.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min A A) = A - Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A); - RewriteResponse response3 = d_rewriter->postRewrite(n3); - TS_ASSERT(response3.d_node == A - && response3.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min A (union_max A B) = A - Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB); - RewriteResponse response4 = d_rewriter->postRewrite(n4); - TS_ASSERT(response4.d_node == A - && response4.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min A (union_max B A) = A - Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA); - RewriteResponse response5 = d_rewriter->postRewrite(n5); - TS_ASSERT(response5.d_node == A - && response4.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min (union_max A B) A) = A - Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A); - RewriteResponse response6 = d_rewriter->postRewrite(n6); - TS_ASSERT(response6.d_node == A - && response6.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min (union_max B A) A) = A - Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A); - RewriteResponse response7 = d_rewriter->postRewrite(n7); - TS_ASSERT(response7.d_node == A - && response7.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min A (union_disjoint A B) = A - Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB); - RewriteResponse response8 = d_rewriter->postRewrite(n8); - TS_ASSERT(response8.d_node == A - && response8.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min A (union_disjoint B A) = A - Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA); - RewriteResponse response9 = d_rewriter->postRewrite(n9); - TS_ASSERT(response9.d_node == A - && response9.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min (union_disjoint A B) A) = A - Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A); - RewriteResponse response10 = d_rewriter->postRewrite(n10); - TS_ASSERT(response10.d_node == A - && response10.d_status == REWRITE_AGAIN_FULL); - - // (intersection_min (union_disjoint B A) A) = A - Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A); - RewriteResponse response11 = d_rewriter->postRewrite(n11); - TS_ASSERT(response11.d_node == A - && response11.d_status == REWRITE_AGAIN_FULL); - } - - void testDifferenceSubtract() - { - int n = 3; - vector elements = getNStrings(2); - Node emptyBag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); - Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); - Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); - Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); - Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); - Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); - Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); - Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); - - // (difference_subtract A emptybag) = A - Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag); - RewriteResponse response1 = d_rewriter->postRewrite(n1); - TS_ASSERT(response1.d_node == A - && response1.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract emptybag A) = emptyBag - Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A); - RewriteResponse response2 = d_rewriter->postRewrite(n2); - TS_ASSERT(response2.d_node == emptyBag - && response2.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract A A) = emptybag - Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A); - RewriteResponse response3 = d_rewriter->postRewrite(n3); - TS_ASSERT(response3.d_node == emptyBag - && response3.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract (union_disjoint A B) A) = B - Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A); - RewriteResponse response4 = d_rewriter->postRewrite(n4); - TS_ASSERT(response4.d_node == B - && response4.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract (union_disjoint B A) A) = B - Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A); - RewriteResponse response5 = d_rewriter->postRewrite(n5); - TS_ASSERT(response5.d_node == B - && response4.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract A (union_disjoint A B)) = emptybag - Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB); - RewriteResponse response6 = d_rewriter->postRewrite(n6); - TS_ASSERT(response6.d_node == emptyBag - && response6.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract A (union_disjoint B A)) = emptybag - Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA); - RewriteResponse response7 = d_rewriter->postRewrite(n7); - TS_ASSERT(response7.d_node == emptyBag - && response7.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract A (union_max A B)) = emptybag - Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB); - RewriteResponse response8 = d_rewriter->postRewrite(n8); - TS_ASSERT(response8.d_node == emptyBag - && response8.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract A (union_max B A)) = emptybag - Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA); - RewriteResponse response9 = d_rewriter->postRewrite(n9); - TS_ASSERT(response9.d_node == emptyBag - && response9.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract (intersection_min A B) A) = emptybag - Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A); - RewriteResponse response10 = d_rewriter->postRewrite(n10); - TS_ASSERT(response10.d_node == emptyBag - && response10.d_status == REWRITE_AGAIN_FULL); - - // (difference_subtract (intersection_min B A) A) = emptybag - Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A); - RewriteResponse response11 = d_rewriter->postRewrite(n11); - TS_ASSERT(response11.d_node == emptyBag - && response11.d_status == REWRITE_AGAIN_FULL); - } - - void testDifferenceRemove() - { - int n = 3; - vector elements = getNStrings(2); - Node emptyBag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); - Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); - Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); - Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); - Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); - Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); - Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); - Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); - - // (difference_remove A emptybag) = A - Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag); - RewriteResponse response1 = d_rewriter->postRewrite(n1); - TS_ASSERT(response1.d_node == A - && response1.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove emptybag A) = emptyBag - Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A); - RewriteResponse response2 = d_rewriter->postRewrite(n2); - TS_ASSERT(response2.d_node == emptyBag - && response2.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove A A) = emptybag - Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A); - RewriteResponse response3 = d_rewriter->postRewrite(n3); - TS_ASSERT(response3.d_node == emptyBag - && response3.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove A (union_disjoint A B)) = emptybag - Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB); - RewriteResponse response6 = d_rewriter->postRewrite(n6); - TS_ASSERT(response6.d_node == emptyBag - && response6.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove A (union_disjoint B A)) = emptybag - Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA); - RewriteResponse response7 = d_rewriter->postRewrite(n7); - TS_ASSERT(response7.d_node == emptyBag - && response7.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove A (union_max A B)) = emptybag - Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB); - RewriteResponse response8 = d_rewriter->postRewrite(n8); - TS_ASSERT(response8.d_node == emptyBag - && response8.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove A (union_max B A)) = emptybag - Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA); - RewriteResponse response9 = d_rewriter->postRewrite(n9); - TS_ASSERT(response9.d_node == emptyBag - && response9.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove (intersection_min A B) A) = emptybag - Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A); - RewriteResponse response10 = d_rewriter->postRewrite(n10); - TS_ASSERT(response10.d_node == emptyBag - && response10.d_status == REWRITE_AGAIN_FULL); - - // (difference_remove (intersection_min B A) A) = emptybag - Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A); - RewriteResponse response11 = d_rewriter->postRewrite(n11); - TS_ASSERT(response11.d_node == emptyBag - && response11.d_status == REWRITE_AGAIN_FULL); - } - - void testChoose() - { - Node x = d_nm->mkSkolem("x", d_nm->stringType()); - Node c = d_nm->mkConst(Rational(3)); - Node bag = d_nm->mkNode(MK_BAG, x, c); - - // (bag.choose (mkBag x c)) = x where c is a constant > 0 - Node n1 = d_nm->mkNode(BAG_CHOOSE, bag); - RewriteResponse response1 = d_rewriter->postRewrite(n1); - TS_ASSERT(response1.d_node == x - && response1.d_status == REWRITE_AGAIN_FULL); - } - - void testBagCard() - { - Node x = d_nm->mkSkolem("x", d_nm->stringType()); - Node emptyBag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - Node zero = d_nm->mkConst(Rational(0)); - Node c = d_nm->mkConst(Rational(3)); - Node bag = d_nm->mkNode(MK_BAG, x, c); - vector elements = getNStrings(2); - Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(4))); - Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(5))); - Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); - - // TODO(projects#223): enable this test after implementing bags normal form - // // (bag.card emptybag) = 0 - // Node n1 = d_nm->mkNode(BAG_CARD, emptyBag); - // RewriteResponse response1 = d_rewriter->postRewrite(n1); - // TS_ASSERT(response1.d_node == zero && response1.d_status == - // REWRITE_AGAIN_FULL); - - // (bag.card (mkBag x c)) = c where c is a constant > 0 - Node n2 = d_nm->mkNode(BAG_CARD, bag); - RewriteResponse response2 = d_rewriter->postRewrite(n2); - TS_ASSERT(response2.d_node == c - && response2.d_status == REWRITE_AGAIN_FULL); - - // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B)) - Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB); - Node cardA = d_nm->mkNode(BAG_CARD, A); - Node cardB = d_nm->mkNode(BAG_CARD, B); - Node plus = d_nm->mkNode(PLUS, cardA, cardB); - RewriteResponse response3 = d_rewriter->postRewrite(n3); - TS_ASSERT(response3.d_node == plus - && response3.d_status == REWRITE_AGAIN_FULL); - } - - void testIsSingleton() - { - Node emptybag = - d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); - Node x = d_nm->mkSkolem("x", d_nm->stringType()); - Node c = d_nm->mkSkolem("c", d_nm->integerType()); - Node bag = d_nm->mkNode(MK_BAG, x, c); - - // TODO(projects#223): complete this function - // (bag.is_singleton emptybag) = false - // Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag); - // RewriteResponse response1 = d_rewriter->postRewrite(n1); - // TS_ASSERT(response1.d_node == d_nm->mkConst(false) - // && response1.d_status == REWRITE_AGAIN_FULL); - - // (bag.is_singleton (mkBag x c) = (c == 1) - Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag); - RewriteResponse response2 = d_rewriter->postRewrite(n2); - Node one = d_nm->mkConst(Rational(1)); - Node equal = c.eqNode(one); - TS_ASSERT(response2.d_node == equal - && response2.d_status == REWRITE_AGAIN_FULL); - } - - void testFromSet() - { - Node x = d_nm->mkSkolem("x", d_nm->stringType()); - Node singleton = d_nm->mkSingleton(d_nm->stringType(), x); - - // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1) - Node n = d_nm->mkNode(BAG_FROM_SET, singleton); - RewriteResponse response = d_rewriter->postRewrite(n); - Node one = d_nm->mkConst(Rational(1)); - Node bag = d_nm->mkNode(MK_BAG, x, one); - TS_ASSERT(response.d_node == bag - && response.d_status == REWRITE_AGAIN_FULL); - } - - void testToSet() - { - Node x = d_nm->mkSkolem("x", d_nm->stringType()); - Node bag = d_nm->mkNode(MK_BAG, x, d_nm->mkConst(Rational(5))); - - // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x) - Node n = d_nm->mkNode(BAG_TO_SET, bag); - RewriteResponse response = d_rewriter->postRewrite(n); - Node singleton = d_nm->mkSingleton(d_nm->stringType(), x); - TS_ASSERT(response.d_node == singleton - && response.d_status == REWRITE_AGAIN_FULL); - } - - private: - std::unique_ptr d_em; - std::unique_ptr d_smt; - std::unique_ptr d_nm; - - std::unique_ptr d_rewriter; -}; /* class BagsTypeRuleBlack */ diff --git a/test/unit/theory/theory_bags_rewriter_white.h b/test/unit/theory/theory_bags_rewriter_white.h new file mode 100644 index 000000000..b1c75fdbd --- /dev/null +++ b/test/unit/theory/theory_bags_rewriter_white.h @@ -0,0 +1,638 @@ +/********************* */ +/*! \file theory_bags_rewriter_white.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief White box testing of bags rewriter + **/ + +#include + +#include "expr/dtype.h" +#include "smt/smt_engine.h" +#include "theory/bags/bags_rewriter.h" +#include "theory/strings/type_enumerator.h" + +using namespace CVC4; +using namespace CVC4::smt; +using namespace CVC4::theory; +using namespace CVC4::kind; +using namespace CVC4::theory::bags; +using namespace std; + +typedef expr::Attribute attribute; + +class BagsTypeRuleWhite : public CxxTest::TestSuite +{ + public: + void setUp() override + { + d_em.reset(new ExprManager()); + d_smt.reset(new SmtEngine(d_em.get())); + d_nm.reset(NodeManager::fromExprManager(d_em.get())); + d_smt->finishInit(); + d_rewriter.reset(new BagsRewriter(nullptr)); + } + + void tearDown() override + { + d_rewriter.reset(); + d_smt.reset(); + d_nm.release(); + d_em.reset(); + } + + std::vector getNStrings(size_t n) + { + std::vector elements(n); + for (size_t i = 0; i < n; i++) + { + elements[i] = d_nm->mkSkolem("x", d_nm->stringType()); + } + return elements; + } + + void testEmptyBagNormalForm() + { + Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType())); + // empty bags are in normal form + TS_ASSERT(emptybag.isConst()); + RewriteResponse response = d_rewriter->postRewrite(emptybag); + TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE); + } + + void testBagEquality() + { + vector elements = getNStrings(2); + Node x = elements[0]; + Node y = elements[1]; + Node c = d_nm->mkSkolem("c", d_nm->integerType()); + Node d = d_nm->mkSkolem("d", d_nm->integerType()); + Node bagX = d_nm->mkBag(d_nm->stringType(), x, c); + Node bagY = d_nm->mkBag(d_nm->stringType(), y, d); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + + // (= A A) = true where A is a bag + Node n1 = emptyBag.eqNode(emptyBag); + RewriteResponse response1 = d_rewriter->preRewrite(n1); + TS_ASSERT(response1.d_node == d_nm->mkConst(true) + && response1.d_status == REWRITE_AGAIN_FULL); + } + + void testMkBagConstantElement() + { + vector elements = getNStrings(1); + Node negative = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1))); + Node zero = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0))); + Node positive = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1))); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + RewriteResponse negativeResponse = d_rewriter->postRewrite(negative); + RewriteResponse zeroResponse = d_rewriter->postRewrite(zero); + RewriteResponse positiveResponse = d_rewriter->postRewrite(positive); + + // bags with non-positive multiplicity are rewritten as empty bags + TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL + && negativeResponse.d_node == emptybag); + TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL + && zeroResponse.d_node == emptybag); + + // no change for positive + TS_ASSERT(positiveResponse.d_status == REWRITE_DONE + && positive == positiveResponse.d_node); + } + + void testMkBagVariableElement() + { + Node skolem = d_nm->mkSkolem("x", d_nm->stringType()); + Node variable = + d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(-1))); + Node negative = + d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(-1))); + Node zero = + d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(0))); + Node positive = + d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(1))); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + RewriteResponse negativeResponse = d_rewriter->postRewrite(negative); + RewriteResponse zeroResponse = d_rewriter->postRewrite(zero); + RewriteResponse positiveResponse = d_rewriter->postRewrite(positive); + + // bags with non-positive multiplicity are rewritten as empty bags + TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL + && negativeResponse.d_node == emptybag); + TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL + && zeroResponse.d_node == emptybag); + + // no change for positive + TS_ASSERT(positiveResponse.d_status == REWRITE_DONE + && positive == positiveResponse.d_node); + } + + void testBagCount() + { + int n = 3; + Node skolem = d_nm->mkSkolem("x", d_nm->stringType()); + Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType()))); + Node bag = + d_nm->mkBag(d_nm->stringType(), skolem, d_nm->mkConst(Rational(n))); + + // (bag.count x emptybag) = 0 + Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL + && response1.d_node == d_nm->mkConst(Rational(0))); + + // (bag.count x (mkBag x c) = c where c > 0 is a constant + Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL + && response2.d_node == d_nm->mkConst(Rational(n))); + } + + void testUnionMax() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkBag( + d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + + // (union_max A emptybag) = A + Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(unionMax1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (union_max emptybag A) = A + Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(unionMax2); + TS_ASSERT(response2.d_node == A + && response2.d_status == REWRITE_AGAIN_FULL); + + // (union_max A A) = A + Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(unionMax3); + TS_ASSERT(response3.d_node == A + && response3.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_max A B)) = (union_max A B) + Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB); + RewriteResponse response4 = d_rewriter->postRewrite(unionMax4); + TS_ASSERT(response4.d_node == unionMaxAB + && response4.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_max B A)) = (union_max B A) + Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA); + RewriteResponse response5 = d_rewriter->postRewrite(unionMax5); + TS_ASSERT(response5.d_node == unionMaxBA + && response4.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_max A B) A) = (union_max A B) + Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A); + RewriteResponse response6 = d_rewriter->postRewrite(unionMax6); + TS_ASSERT(response6.d_node == unionMaxAB + && response6.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_max B A) A) = (union_max B A) + Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A); + RewriteResponse response7 = d_rewriter->postRewrite(unionMax7); + TS_ASSERT(response7.d_node == unionMaxBA + && response7.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_disjoint A B)) = (union_disjoint A B) + Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB); + RewriteResponse response8 = d_rewriter->postRewrite(unionMax8); + TS_ASSERT(response8.d_node == unionDisjointAB + && response8.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_disjoint B A)) = (union_disjoint B A) + Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA); + RewriteResponse response9 = d_rewriter->postRewrite(unionMax9); + TS_ASSERT(response9.d_node == unionDisjointBA + && response9.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_disjoint A B) A) = (union_disjoint A B) + Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(unionMax10); + TS_ASSERT(response10.d_node == unionDisjointAB + && response10.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_disjoint B A) A) = (union_disjoint B A) + Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(unionMax11); + TS_ASSERT(response11.d_node == unionDisjointBA + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testUnionDisjoint() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkBag( + d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); + Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); + + // (union_disjoint A emptybag) = A + Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint emptybag A) = A + Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2); + TS_ASSERT(response2.d_node == A + && response2.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint (union_max A B) (intersection_min B A)) = + // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b) + Node unionDisjoint3 = + d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA); + RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3); + TS_ASSERT(response3.d_node == unionDisjointAB + && response3.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint (intersection_min B A)) (union_max A B) = + // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b) + Node unionDisjoint4 = + d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA); + RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4); + TS_ASSERT(response4.d_node == unionDisjointBA + && response4.d_status == REWRITE_AGAIN_FULL); + } + + void testIntersectionMin() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkBag( + d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + + // (intersection_min A emptybag) = emptyBag + Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == emptyBag + && response1.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min emptybag A) = emptyBag + Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == emptyBag + && response2.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A A) = A + Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == A + && response3.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_max A B) = A + Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB); + RewriteResponse response4 = d_rewriter->postRewrite(n4); + TS_ASSERT(response4.d_node == A + && response4.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_max B A) = A + Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA); + RewriteResponse response5 = d_rewriter->postRewrite(n5); + TS_ASSERT(response5.d_node == A + && response4.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_max A B) A) = A + Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A); + RewriteResponse response6 = d_rewriter->postRewrite(n6); + TS_ASSERT(response6.d_node == A + && response6.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_max B A) A) = A + Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A); + RewriteResponse response7 = d_rewriter->postRewrite(n7); + TS_ASSERT(response7.d_node == A + && response7.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_disjoint A B) = A + Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB); + RewriteResponse response8 = d_rewriter->postRewrite(n8); + TS_ASSERT(response8.d_node == A + && response8.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_disjoint B A) = A + Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA); + RewriteResponse response9 = d_rewriter->postRewrite(n9); + TS_ASSERT(response9.d_node == A + && response9.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_disjoint A B) A) = A + Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(n10); + TS_ASSERT(response10.d_node == A + && response10.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_disjoint B A) A) = A + Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(n11); + TS_ASSERT(response11.d_node == A + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testDifferenceSubtract() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkBag( + d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); + Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); + + // (difference_subtract A emptybag) = A + Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract emptybag A) = emptyBag + Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == emptyBag + && response2.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A A) = emptybag + Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == emptyBag + && response3.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (union_disjoint A B) A) = B + Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A); + RewriteResponse response4 = d_rewriter->postRewrite(n4); + TS_ASSERT(response4.d_node == B + && response4.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (union_disjoint B A) A) = B + Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A); + RewriteResponse response5 = d_rewriter->postRewrite(n5); + TS_ASSERT(response5.d_node == B + && response4.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_disjoint A B)) = emptybag + Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB); + RewriteResponse response6 = d_rewriter->postRewrite(n6); + TS_ASSERT(response6.d_node == emptyBag + && response6.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_disjoint B A)) = emptybag + Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA); + RewriteResponse response7 = d_rewriter->postRewrite(n7); + TS_ASSERT(response7.d_node == emptyBag + && response7.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_max A B)) = emptybag + Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB); + RewriteResponse response8 = d_rewriter->postRewrite(n8); + TS_ASSERT(response8.d_node == emptyBag + && response8.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_max B A)) = emptybag + Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA); + RewriteResponse response9 = d_rewriter->postRewrite(n9); + TS_ASSERT(response9.d_node == emptyBag + && response9.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (intersection_min A B) A) = emptybag + Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(n10); + TS_ASSERT(response10.d_node == emptyBag + && response10.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (intersection_min B A) A) = emptybag + Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(n11); + TS_ASSERT(response11.d_node == emptyBag + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testDifferenceRemove() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkBag( + d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); + Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); + + // (difference_remove A emptybag) = A + Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove emptybag A) = emptyBag + Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == emptyBag + && response2.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A A) = emptybag + Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == emptyBag + && response3.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_disjoint A B)) = emptybag + Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB); + RewriteResponse response6 = d_rewriter->postRewrite(n6); + TS_ASSERT(response6.d_node == emptyBag + && response6.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_disjoint B A)) = emptybag + Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA); + RewriteResponse response7 = d_rewriter->postRewrite(n7); + TS_ASSERT(response7.d_node == emptyBag + && response7.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_max A B)) = emptybag + Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB); + RewriteResponse response8 = d_rewriter->postRewrite(n8); + TS_ASSERT(response8.d_node == emptyBag + && response8.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_max B A)) = emptybag + Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA); + RewriteResponse response9 = d_rewriter->postRewrite(n9); + TS_ASSERT(response9.d_node == emptyBag + && response9.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove (intersection_min A B) A) = emptybag + Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(n10); + TS_ASSERT(response10.d_node == emptyBag + && response10.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove (intersection_min B A) A) = emptybag + Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(n11); + TS_ASSERT(response11.d_node == emptyBag + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testChoose() + { + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node c = d_nm->mkConst(Rational(3)); + Node bag = d_nm->mkBag(d_nm->stringType(), x, c); + + // (bag.choose (mkBag x c)) = x where c is a constant > 0 + Node n1 = d_nm->mkNode(BAG_CHOOSE, bag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == x + && response1.d_status == REWRITE_AGAIN_FULL); + } + + void testBagCard() + { + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node zero = d_nm->mkConst(Rational(0)); + Node c = d_nm->mkConst(Rational(3)); + Node bag = d_nm->mkBag(d_nm->stringType(), x, c); + vector elements = getNStrings(2); + Node A = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(4))); + Node B = d_nm->mkBag( + d_nm->stringType(), elements[1], d_nm->mkConst(Rational(5))); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + + // TODO(projects#223): enable this test after implementing bags normal form + // // (bag.card emptybag) = 0 + // Node n1 = d_nm->mkNode(BAG_CARD, emptyBag); + // RewriteResponse response1 = d_rewriter->postRewrite(n1); + // TS_ASSERT(response1.d_node == zero && response1.d_status == + // REWRITE_AGAIN_FULL); + + // (bag.card (mkBag x c)) = c where c is a constant > 0 + Node n2 = d_nm->mkNode(BAG_CARD, bag); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == c + && response2.d_status == REWRITE_AGAIN_FULL); + + // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B)) + Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB); + Node cardA = d_nm->mkNode(BAG_CARD, A); + Node cardB = d_nm->mkNode(BAG_CARD, B); + Node plus = d_nm->mkNode(PLUS, cardA, cardB); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == plus + && response3.d_status == REWRITE_AGAIN_FULL); + } + + void testIsSingleton() + { + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node c = d_nm->mkSkolem("c", d_nm->integerType()); + Node bag = d_nm->mkBag(d_nm->stringType(), x, c); + + // TODO(projects#223): complete this function + // (bag.is_singleton emptybag) = false + // Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag); + // RewriteResponse response1 = d_rewriter->postRewrite(n1); + // TS_ASSERT(response1.d_node == d_nm->mkConst(false) + // && response1.d_status == REWRITE_AGAIN_FULL); + + // (bag.is_singleton (mkBag x c) = (c == 1) + Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + Node one = d_nm->mkConst(Rational(1)); + Node equal = c.eqNode(one); + TS_ASSERT(response2.d_node == equal + && response2.d_status == REWRITE_AGAIN_FULL); + } + + void testFromSet() + { + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node singleton = d_nm->mkSingleton(d_nm->stringType(), x); + + // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1) + Node n = d_nm->mkNode(BAG_FROM_SET, singleton); + RewriteResponse response = d_rewriter->postRewrite(n); + Node one = d_nm->mkConst(Rational(1)); + Node bag = d_nm->mkBag(d_nm->stringType(), x, one); + TS_ASSERT(response.d_node == bag + && response.d_status == REWRITE_AGAIN_FULL); + } + + void testToSet() + { + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node bag = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(5))); + + // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x) + Node n = d_nm->mkNode(BAG_TO_SET, bag); + RewriteResponse response = d_rewriter->postRewrite(n); + Node singleton = d_nm->mkSingleton(d_nm->stringType(), x); + TS_ASSERT(response.d_node == singleton + && response.d_status == REWRITE_AGAIN_FULL); + } + + private: + std::unique_ptr d_em; + std::unique_ptr d_smt; + std::unique_ptr d_nm; + + std::unique_ptr d_rewriter; +}; /* class BagsTypeRuleBlack */ diff --git a/test/unit/theory/theory_bags_type_rules_black.h b/test/unit/theory/theory_bags_type_rules_black.h deleted file mode 100644 index d6c225bad..000000000 --- a/test/unit/theory/theory_bags_type_rules_black.h +++ /dev/null @@ -1,111 +0,0 @@ -/********************* */ -/*! \file theory_bags_type_rules_black.h - ** \verbatim - ** Top contributors (to current version): - ** Mudathir Mohamed - ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS - ** in the top-level source directory) and their institutional affiliations. - ** All rights reserved. See the file COPYING in the top-level source - ** directory for licensing information.\endverbatim - ** - ** \brief Black box testing of bags typing rules - **/ - -#include - -#include "expr/dtype.h" -#include "smt/smt_engine.h" -#include "theory/bags/theory_bags_type_rules.h" -#include "theory/strings/type_enumerator.h" - -using namespace CVC4; -using namespace CVC4::smt; -using namespace CVC4::theory; -using namespace CVC4::kind; -using namespace CVC4::theory::bags; -using namespace std; - -typedef expr::Attribute attribute; - -class BagsTypeRuleBlack : public CxxTest::TestSuite -{ - public: - void setUp() override - { - d_em.reset(new ExprManager()); - d_smt.reset(new SmtEngine(d_em.get())); - d_nm.reset(NodeManager::fromExprManager(d_em.get())); - d_smt->finishInit(); - } - - void tearDown() override - { - d_smt.reset(); - d_nm.release(); - d_em.reset(); - } - - std::vector getNStrings(size_t n) - { - std::vector elements(n); - CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType()); - - for (size_t i = 0; i < n; i++) - { - ++enumerator; - elements[i] = *enumerator; - } - - return elements; - } - - void testCountOperator() - { - vector elements = getNStrings(1); - Node bag = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(100))); - - Node count = d_nm->mkNode(BAG_COUNT, elements[0], bag); - Node node = d_nm->mkConst(Rational(10)); - - // node of type Int is not compatible with bag of type (Bag String) - TS_ASSERT_THROWS(d_nm->mkNode(BAG_COUNT, node, bag).getType(true), - TypeCheckingExceptionPrivate&); - } - - void testMkBagOperator() - { - vector elements = getNStrings(1); - Node negative = - d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1))); - Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0))); - Node positive = - d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1))); - - // only positive multiplicity are constants - TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), negative)); - TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), zero)); - TS_ASSERT(MkBagTypeRule::computeIsConst(d_nm.get(), positive)); - } - - void testFromSetOperator() - { - vector elements = getNStrings(1); - Node set = d_nm->mkSingleton(d_nm->stringType(), elements[0]); - TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_FROM_SET, set)); - TS_ASSERT(d_nm->mkNode(BAG_FROM_SET, set).getType().isBag()); - } - - void testToSetOperator() - { - vector elements = getNStrings(1); - Node bag = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(10))); - TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag)); - TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet()); - } - - private: - std::unique_ptr d_em; - std::unique_ptr d_smt; - std::unique_ptr d_nm; -}; /* class BagsTypeRuleBlack */ diff --git a/test/unit/theory/theory_bags_type_rules_white.h b/test/unit/theory/theory_bags_type_rules_white.h new file mode 100644 index 000000000..dfe2d4cac --- /dev/null +++ b/test/unit/theory/theory_bags_type_rules_white.h @@ -0,0 +1,113 @@ +/********************* */ +/*! \file theory_bags_type_rules_black.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Black box testing of bags typing rules + **/ + +#include + +#include "expr/dtype.h" +#include "smt/smt_engine.h" +#include "theory/bags/theory_bags_type_rules.h" +#include "theory/strings/type_enumerator.h" + +using namespace CVC4; +using namespace CVC4::smt; +using namespace CVC4::theory; +using namespace CVC4::kind; +using namespace CVC4::theory::bags; +using namespace std; + +typedef expr::Attribute attribute; + +class BagsTypeRuleWhite : public CxxTest::TestSuite +{ + public: + void setUp() override + { + d_em.reset(new ExprManager()); + d_smt.reset(new SmtEngine(d_em.get())); + d_nm.reset(NodeManager::fromExprManager(d_em.get())); + d_smt->finishInit(); + } + + void tearDown() override + { + d_smt.reset(); + d_nm.release(); + d_em.reset(); + } + + std::vector getNStrings(size_t n) + { + std::vector elements(n); + CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType()); + + for (size_t i = 0; i < n; i++) + { + ++enumerator; + elements[i] = *enumerator; + } + + return elements; + } + + void testCountOperator() + { + vector elements = getNStrings(1); + Node bag = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(100))); + + Node count = d_nm->mkNode(BAG_COUNT, elements[0], bag); + Node node = d_nm->mkConst(Rational(10)); + + // node of type Int is not compatible with bag of type (Bag String) + TS_ASSERT_THROWS(d_nm->mkNode(BAG_COUNT, node, bag).getType(true), + TypeCheckingExceptionPrivate&); + } + + void testMkBagOperator() + { + vector elements = getNStrings(1); + Node negative = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1))); + Node zero = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0))); + Node positive = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1))); + + // only positive multiplicity are constants + TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), negative)); + TS_ASSERT(!MkBagTypeRule::computeIsConst(d_nm.get(), zero)); + TS_ASSERT(MkBagTypeRule::computeIsConst(d_nm.get(), positive)); + } + + void testFromSetOperator() + { + vector elements = getNStrings(1); + Node set = d_nm->mkSingleton(d_nm->stringType(), elements[0]); + TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_FROM_SET, set)); + TS_ASSERT(d_nm->mkNode(BAG_FROM_SET, set).getType().isBag()); + } + + void testToSetOperator() + { + vector elements = getNStrings(1); + Node bag = d_nm->mkBag(d_nm->stringType(), elements[0], d_nm->mkConst(Rational(10))); + TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag)); + TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet()); + } + + private: + std::unique_ptr d_em; + std::unique_ptr d_smt; + std::unique_ptr d_nm; +}; /* class BagsTypeRuleBlack */