From 0f77646dfc0944f1f17f121ffb3112bf8b244f76 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Mon, 28 Sep 2020 08:53:07 -0500 Subject: [PATCH] Implement bags rewriter (#5132) This PR implements rewrite rules for bags. This PR focuses on rewrite rules for non constant nodes. Rewriting nodes with constant children is delegated to bags::NormalForm class (future PR). --- src/CMakeLists.txt | 8 +- src/theory/bags/bags_rewriter.cpp | 434 +++++++++++++ src/theory/bags/bags_rewriter.h | 195 ++++++ ..._bags_rewriter.cpp => bags_statistics.cpp} | 20 +- src/theory/bags/bags_statistics.h | 45 ++ src/theory/bags/kinds | 4 +- src/theory/bags/normal_form.cpp | 10 + src/theory/bags/normal_form.h | 11 +- src/theory/bags/rewrites.cpp | 76 +++ src/theory/bags/rewrites.h | 91 +++ src/theory/bags/theory_bags.cpp | 5 +- src/theory/bags/theory_bags.h | 7 +- src/theory/bags/theory_bags_rewriter.h | 38 -- src/theory/bags/theory_bags_type_enumerator.h | 2 +- src/theory/bags/theory_bags_type_rules.h | 2 +- test/unit/theory/CMakeLists.txt | 1 + test/unit/theory/theory_bags_rewriter_black.h | 593 ++++++++++++++++++ 17 files changed, 1482 insertions(+), 60 deletions(-) create mode 100644 src/theory/bags/bags_rewriter.cpp create mode 100644 src/theory/bags/bags_rewriter.h rename src/theory/bags/{theory_bags_rewriter.cpp => bags_statistics.cpp} (50%) create mode 100644 src/theory/bags/bags_statistics.h create mode 100644 src/theory/bags/rewrites.cpp create mode 100644 src/theory/bags/rewrites.h delete mode 100644 src/theory/bags/theory_bags_rewriter.h create mode 100644 test/unit/theory/theory_bags_rewriter_black.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 717378b27..74dcc39b3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -404,18 +404,22 @@ libcvc4_add_sources( theory/assertion.h theory/atom_requests.cpp theory/atom_requests.h + theory/bags/bags_rewriter.cpp + theory/bags/bags_rewriter.h + theory/bags/bags_statistics.cpp + theory/bags/bags_statistics.h theory/bags/inference_manager.cpp theory/bags/inference_manager.h theory/bags/normal_form.cpp theory/bags/normal_form.h + theory/bags/rewrites.cpp + theory/bags/rewrites.h theory/bags/solver_state.cpp theory/bags/solver_state.h theory/bags/term_registry.cpp theory/bags/term_registry.h theory/bags/theory_bags.cpp theory/bags/theory_bags.h - theory/bags/theory_bags_rewriter.cpp - theory/bags/theory_bags_rewriter.h theory/bags/theory_bags_type_enumerator.cpp theory/bags/theory_bags_type_enumerator.h theory/bags/theory_bags_type_rules.h diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp new file mode 100644 index 000000000..a6506f18e --- /dev/null +++ b/src/theory/bags/bags_rewriter.cpp @@ -0,0 +1,434 @@ +/********************* */ +/*! \file bags_rewriter.cpp + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Bags theory rewriter. + **/ + +#include "theory/bags/bags_rewriter.h" + +#include "normal_form.h" + +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace bags { + +BagsRewriteResponse::BagsRewriteResponse() + : d_node(Node::null()), d_rewrite(Rewrite::NONE) +{ +} + +BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite) + : d_node(n), d_rewrite(rewrite) +{ +} + +BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r) + : d_node(r.d_node), d_rewrite(r.d_rewrite) +{ +} + +BagsRewriter::BagsRewriter(HistogramStat* statistics) + : d_statistics(statistics) +{ + d_nm = NodeManager::currentNM(); +} + +RewriteResponse BagsRewriter::postRewrite(TNode n) +{ + BagsRewriteResponse response; + if (n.isConst()) + { + // no need to rewrite n if it is already in a normal form + response = BagsRewriteResponse(n, Rewrite::NONE); + } + else if (NormalForm::AreChildrenConstants(n)) + { + Node value = NormalForm::evaluate(n); + response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION); + } + else + { + Kind k = n.getKind(); + switch (k) + { + case MK_BAG: response = rewriteMakeBag(n); break; + case BAG_COUNT: response = rewriteBagCount(n); break; + case UNION_MAX: response = rewriteUnionMax(n); break; + case UNION_DISJOINT: response = rewriteUnionDisjoint(n); break; + case INTERSECTION_MIN: response = rewriteIntersectionMin(n); break; + case DIFFERENCE_SUBTRACT: response = rewriteDifferenceSubtract(n); break; + case DIFFERENCE_REMOVE: response = rewriteDifferenceRemove(n); break; + case BAG_CHOOSE: response = rewriteChoose(n); break; + case BAG_CARD: response = rewriteCard(n); break; + case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break; + default: response = BagsRewriteResponse(n, Rewrite::NONE); break; + } + } + + Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node + << " by " << response.d_rewrite << "." << std::endl; + + if (d_statistics != nullptr) + { + (*d_statistics) << response.d_rewrite; + } + if (response.d_node != n) + { + return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node); + } + return RewriteResponse(RewriteStatus::REWRITE_DONE, n); +} + +RewriteResponse BagsRewriter::preRewrite(TNode n) +{ + BagsRewriteResponse response; + Kind k = n.getKind(); + switch (k) + { + case EQUAL: response = rewriteEqual(n); break; + case BAG_IS_INCLUDED: response = rewriteIsIncluded(n); break; + default: response = BagsRewriteResponse(n, Rewrite::NONE); + } + + Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node + << " by " << response.d_rewrite << "." << std::endl; + + if (d_statistics != nullptr) + { + (*d_statistics) << response.d_rewrite; + } + if (response.d_node != n) + { + return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node); + } + return RewriteResponse(RewriteStatus::REWRITE_DONE, n); +} + +BagsRewriteResponse BagsRewriter::rewriteEqual(const TNode& n) const +{ + Assert(n.getKind() == EQUAL); + if (n[0] == n[1]) + { + // (= A A) = true where A is a bag + return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteIsIncluded(const TNode& n) const +{ + Assert(n.getKind() == BAG_IS_INCLUDED); + + // (bag.is_included A B) = ((difference_subtract A B) == emptybag) + Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType())); + Node subtract = d_nm->mkNode(DIFFERENCE_SUBTRACT, n[0], n[1]); + Node equal = subtract.eqNode(emptybag); + return BagsRewriteResponse(equal, Rewrite::SUB_BAG); +} + +BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const +{ + Assert(n.getKind() == MK_BAG); + // return emptybag for negative or zero multiplicity + if (n[1].isConst() && n[1].getConst().sgn() != 1) + { + // (mkBag x c) = emptybag where c <= 0 + Node emptybag = d_nm->mkConst(EmptyBag(n.getType())); + return BagsRewriteResponse(emptybag, Rewrite::MK_BAG_COUNT_NEGATIVE); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const +{ + Assert(n.getKind() == BAG_COUNT); + if (n[1].isConst() && n[1].getKind() == EMPTYBAG) + { + // (bag.count x emptybag) = 0 + return BagsRewriteResponse(d_nm->mkConst(Rational(0)), + Rewrite::COUNT_EMPTY); + } + if (n[1].getKind() == MK_BAG && n[0] == n[1][0]) + { + // (bag.count x (mkBag x c) = c where c > 0 is a constant + return BagsRewriteResponse(n[1][1], Rewrite::COUNT_MK_BAG); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const +{ + Assert(n.getKind() == UNION_MAX); + if (n[1].getKind() == EMPTYBAG || n[0] == n[1]) + { + // (union_max A A) = A + // (union_max A emptybag) = A + return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY); + } + if (n[0].getKind() == EMPTYBAG) + { + // (union_max emptybag A) = A + return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY); + } + + if ((n[1].getKind() == UNION_MAX || n[1].getKind() == UNION_DISJOINT) + && (n[0] == n[1][0] || n[0] == n[1][1])) + { + // (union_max A (union_max A B)) = (union_max A B) + // (union_max A (union_max B A)) = (union_max B A) + // (union_max A (union_disjoint A B)) = (union_disjoint A B) + // (union_max A (union_disjoint B A)) = (union_disjoint B A) + return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT); + } + + if ((n[0].getKind() == UNION_MAX || n[0].getKind() == UNION_DISJOINT) + && (n[0][0] == n[1] || n[0][1] == n[1])) + { + // (union_max (union_max A B) A)) = (union_max A B) + // (union_max (union_max B A) A)) = (union_max B A) + // (union_max (union_disjoint A B) A)) = (union_disjoint A B) + // (union_max (union_disjoint B A) A)) = (union_disjoint B A) + return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const +{ + Assert(n.getKind() == UNION_DISJOINT); + if (n[1].getKind() == EMPTYBAG) + { + // (union_disjoint A emptybag) = A + return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT); + } + if (n[0].getKind() == EMPTYBAG) + { + // (union_disjoint emptybag A) = A + return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT); + } + if ((n[0].getKind() == UNION_MAX && n[1].getKind() == INTERSECTION_MIN) + || (n[1].getKind() == UNION_MAX && n[0].getKind() == INTERSECTION_MIN)) + + { + // (union_disjoint (union_max A B) (intersection_min A B)) = + // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b) + // check if the operands of union_max and intersection_min are the same + std::set left(n[0].begin(), n[0].end()); + std::set right(n[0].begin(), n[0].end()); + if (left == right) + { + Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]); + return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN); + } + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const +{ + Assert(n.getKind() == INTERSECTION_MIN); + if (n[0].getKind() == EMPTYBAG) + { + // (intersection_min emptybag A) = emptybag + return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT); + } + if (n[1].getKind() == EMPTYBAG) + { + // (intersection_min A emptybag) = emptybag + return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT); + } + if (n[0] == n[1]) + { + // (intersection_min A A) = A + return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME); + } + if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX) + { + if (n[0] == n[1][0] || n[0] == n[1][1]) + { + // (intersection_min A (union_disjoint A B)) = A + // (intersection_min A (union_disjoint B A)) = A + // (intersection_min A (union_max A B)) = A + // (intersection_min A (union_max B A)) = A + return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT); + } + } + + if (n[0].getKind() == UNION_DISJOINT || n[0].getKind() == UNION_MAX) + { + if (n[1] == n[0][0] || n[1] == n[0][1]) + { + // (intersection_min (union_disjoint A B) A) = A + // (intersection_min (union_disjoint B A) A) = A + // (intersection_min (union_max A B) A) = A + // (intersection_min (union_max B A) A) = A + return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT); + } + } + + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract( + const TNode& n) const +{ + Assert(n.getKind() == DIFFERENCE_SUBTRACT); + if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG) + { + // (difference_subtract A emptybag) = A + // (difference_subtract emptybag A) = emptybag + return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT); + } + if (n[0] == n[1]) + { + // (difference_subtract A A) = emptybag + Node emptyBag = d_nm->mkConst(EmptyBag(n.getType())); + return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME); + } + + if (n[0].getKind() == UNION_DISJOINT) + { + if (n[1] == n[0][0]) + { + // (difference_subtract (union_disjoint A B) A) = B + return BagsRewriteResponse(n[0][1], + Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT); + } + if (n[1] == n[0][1]) + { + // (difference_subtract (union_disjoint B A) A) = B + return BagsRewriteResponse(n[0][0], + Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT); + } + } + + if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX) + { + if (n[0] == n[1][0] || n[0] == n[1][1]) + { + // (difference_subtract A (union_disjoint A B)) = emptybag + // (difference_subtract A (union_disjoint B A)) = emptybag + // (difference_subtract A (union_max A B)) = emptybag + // (difference_subtract A (union_max B A)) = emptybag + Node emptyBag = d_nm->mkConst(EmptyBag(n.getType())); + return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION); + } + } + + if (n[0].getKind() == INTERSECTION_MIN) + { + if (n[1] == n[0][0] || n[1] == n[0][1]) + { + // (difference_subtract (intersection_min A B) A) = emptybag + // (difference_subtract (intersection_min B A) A) = emptybag + Node emptyBag = d_nm->mkConst(EmptyBag(n.getType())); + return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN); + } + } + + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const +{ + Assert(n.getKind() == DIFFERENCE_REMOVE); + + if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG) + { + // (difference_remove A emptybag) = A + // (difference_remove emptybag B) = emptybag + return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT); + } + + if (n[0] == n[1]) + { + // (difference_remove A A) = emptybag + Node emptyBag = d_nm->mkConst(EmptyBag(n.getType())); + return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME); + } + + if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX) + { + if (n[0] == n[1][0] || n[0] == n[1][1]) + { + // (difference_remove A (union_disjoint A B)) = emptybag + // (difference_remove A (union_disjoint B A)) = emptybag + // (difference_remove A (union_max A B)) = emptybag + // (difference_remove A (union_max B A)) = emptybag + Node emptyBag = d_nm->mkConst(EmptyBag(n.getType())); + return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION); + } + } + + if (n[0].getKind() == INTERSECTION_MIN) + { + if (n[1] == n[0][0] || n[1] == n[0][1]) + { + // (difference_remove (intersection_min A B) A) = emptybag + // (difference_remove (intersection_min B A) A) = emptybag + Node emptyBag = d_nm->mkConst(EmptyBag(n.getType())); + return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN); + } + } + + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const +{ + Assert(n.getKind() == BAG_CHOOSE); + if (n[0].getKind() == MK_BAG && n[0][1].isConst()) + { + // (bag.choose (mkBag x c)) = x where c is a constant > 0 + return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const +{ + Assert(n.getKind() == BAG_CARD); + if (n[0].getKind() == MK_BAG && n[0][1].isConst()) + { + // (bag.card (mkBag x c)) = c where c is a constant > 0 + return BagsRewriteResponse(n[0][1], Rewrite::CARD_MK_BAG); + } + + if (n[0].getKind() == UNION_DISJOINT) + { + // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B)) + Node A = d_nm->mkNode(BAG_CARD, n[0][0]); + Node B = d_nm->mkNode(BAG_CARD, n[0][1]); + Node plus = d_nm->mkNode(PLUS, A, B); + return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT); + } + + return BagsRewriteResponse(n, Rewrite::NONE); +} + +BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const +{ + Assert(n.getKind() == BAG_IS_SINGLETON); + if (n[0].getKind() == MK_BAG) + { + // (bag.is_singleton (mkBag x c)) = (c == 1) + Node one = d_nm->mkConst(Rational(1)); + Node equal = n[0][1].eqNode(one); + return BagsRewriteResponse(equal, Rewrite::IS_SINGLETON_MK_BAG); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + +} // namespace bags +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h new file mode 100644 index 000000000..f0998a205 --- /dev/null +++ b/src/theory/bags/bags_rewriter.h @@ -0,0 +1,195 @@ +/********************* */ +/*! \file bags_rewriter.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Bags theory rewriter. + **/ + +#include "cvc4_private.h" +#include "theory/bags/rewrites.h" + +#ifndef CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H +#define CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H + +#include "theory/rewriter.h" + +namespace CVC4 { +namespace theory { +namespace bags { + +/** a class represents the result of rewriting bag nodes */ +struct BagsRewriteResponse +{ + BagsRewriteResponse(); + BagsRewriteResponse(Node n, Rewrite rewrite); + BagsRewriteResponse(const BagsRewriteResponse& r); + /** the rewritten node */ + Node d_node; + /** type of rewrite used by bags */ + Rewrite d_rewrite; + +}; /* struct BagsRewriteResponse */ + +class BagsRewriter : public TheoryRewriter +{ + public: + BagsRewriter(HistogramStat* statistics = nullptr); + + /** + * postRewrite nodes with kinds: MK_BAG, BAG_COUNT, UNION_MAX, UNION_DISJOINT, + * INTERSECTION_MIN, DIFFERENCE_SUBTRACT, DIFFERENCE_REMOVE, BAG_CHOOSE, + * BAG_CARD, BAG_IS_SINGLETON. + * See the rewrite rules for these kinds below. + */ + RewriteResponse postRewrite(TNode n) override; + /** + * preRewrite nodes with kinds: EQUAL, BAG_IS_INCLUDED. + * See the rewrite rules for these kinds below. + */ + RewriteResponse preRewrite(TNode n) override; + + private: + /** + * rewrites for n include: + * - (= A A) = true where A is a bag + */ + BagsRewriteResponse rewriteEqual(const TNode& n) const; + + /** + * rewrites for n include: + * - (bag.is_included A B) = ((difference_subtract A B) == emptybag) + */ + BagsRewriteResponse rewriteIsIncluded(const TNode& n) const; + + /** + * rewrites for n include: + * - (mkBag x 0) = (emptybag T) where T is the type of x + * - (mkBag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a + * constant + * - otherwise = n + */ + BagsRewriteResponse rewriteMakeBag(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.count x emptybag) = 0 + * - (bag.count x (mkBag x c) = c where c > 0 is a constant + * - otherwise = n + */ + BagsRewriteResponse rewriteBagCount(const TNode& n) const; + + /** + * rewrites for n include: + * - (union_max A emptybag) = A + * - (union_max emptybag A) = A + * - (union_max A A) = A + * - (union_max A (union_max A B)) = (union_max A B) + * - (union_max A (union_max B A)) = (union_max B A) + * - (union_max (union_max A B) A) = (union_max A B) + * - (union_max (union_max B A) A) = (union_max B A) + * - (union_max A (union_disjoint A B)) = (union_disjoint A B) + * - (union_max A (union_disjoint B A)) = (union_disjoint B A) + * - (union_max (union_disjoint A B) A) = (union_disjoint A B) + * - (union_max (union_disjoint B A) A) = (union_disjoint B A) + * - otherwise = n + */ + BagsRewriteResponse rewriteUnionMax(const TNode& n) const; + + /** + * rewrites for n include: + * - (union_disjoint A emptybag) = A + * - (union_disjoint emptybag A) = A + * - (union_disjoint (union_max A B) (intersection_min A B)) = + * (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b) + * - other permutations of the above like swapping A and B, or swapping + * intersection_min and union_max + * - otherwise = n + */ + BagsRewriteResponse rewriteUnionDisjoint(const TNode& n) const; + + /** + * rewrites for n include: + * - (intersection_min A emptybag) = emptybag + * - (intersection_min emptybag A) = emptybag + * - (intersection_min A A) = A + * - (intersection_min A (union_disjoint A B)) = A + * - (intersection_min A (union_disjoint B A)) = A + * - (intersection_min (union_disjoint A B) A) = A + * - (intersection_min (union_disjoint B A) A) = A + * - (intersection_min A (union_max A B)) = A + * - (intersection_min A (union_max B A)) = A + * - (intersection_min (union_max A B) A) = A + * - (intersection_min (union_max B A) A) = A + * - otherwise = n + */ + BagsRewriteResponse rewriteIntersectionMin(const TNode& n) const; + + /** + * rewrites for n include: + * - (difference_subtract A emptybag) = A + * - (difference_subtract emptybag A) = emptybag + * - (difference_subtract A A) = emptybag + * - (difference_subtract (union_disjoint A B) A) = B + * - (difference_subtract (union_disjoint B A) A) = B + * - (difference_subtract A (union_disjoint A B)) = emptybag + * - (difference_subtract A (union_disjoint B A)) = emptybag + * - (difference_subtract A (union_max A B)) = emptybag + * - (difference_subtract A (union_max B A)) = emptybag + * - (difference_subtract (intersection_min A B) A) = emptybag + * - (difference_subtract (intersection_min B A) A) = emptybag + * - otherwise = n + */ + BagsRewriteResponse rewriteDifferenceSubtract(const TNode& n) const; + + /** + * rewrites for n include: + * - (difference_remove A emptybag) = A + * - (difference_remove emptybag A) = emptybag + * - (difference_remove A A) = emptybag + * - (difference_remove A (union_disjoint A B)) = emptybag + * - (difference_remove A (union_disjoint B A)) = emptybag + * - (difference_remove A (union_max A B)) = emptybag + * - (difference_remove A (union_max B A)) = emptybag + * - (difference_remove (intersection_min A B) A) = emptybag + * - (difference_remove (intersection_min B A) A) = emptybag + * - otherwise = n + */ + BagsRewriteResponse rewriteDifferenceRemove(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.choose (mkBag x c)) = x where c is a constant > 0 + * - otherwise = n + */ + BagsRewriteResponse rewriteChoose(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.card (mkBag x c)) = c where c is a constant > 0 + * - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B)) + * - otherwise = n + */ + BagsRewriteResponse rewriteCard(const TNode& n) const; + + /** + * rewrites for n include: + * - (bag.is_singleton (mkBag x c)) = (c == 1) + */ + BagsRewriteResponse rewriteIsSingleton(const TNode& n) const; + + private: + /** Reference to the rewriter statistics. */ + NodeManager* d_nm; + /** Reference to the rewriter statistics. */ + HistogramStat* d_statistics; +}; /* class TheoryBagsRewriter */ + +} // namespace bags +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H */ diff --git a/src/theory/bags/theory_bags_rewriter.cpp b/src/theory/bags/bags_statistics.cpp similarity index 50% rename from src/theory/bags/theory_bags_rewriter.cpp rename to src/theory/bags/bags_statistics.cpp index aaf0ab98c..ea3d3046e 100644 --- a/src/theory/bags/theory_bags_rewriter.cpp +++ b/src/theory/bags/bags_statistics.cpp @@ -1,35 +1,33 @@ /********************* */ -/*! \file theory_bags_rewriter.cpp +/*! \file bags_statistics.cpp ** \verbatim ** Top contributors (to current version): ** Mudathir Mohamed ** This file is part of the CVC4 project. ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS - ** in the top-level source directory) and their institutional affiliations. + ** in the top-level source directory and their institutional affiliations. ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** - ** \brief Bags theory rewriter. + ** \brief Statistics for the theory of bags **/ -#include "theory/bags/theory_bags_rewriter.h" +#include "theory/bags/bags_statistics.h" -using namespace CVC4::kind; +#include "smt/smt_statistics_registry.h" namespace CVC4 { namespace theory { namespace bags { -RewriteResponse TheoryBagsRewriter::postRewrite(TNode node) +BagsStatistics::BagsStatistics() : d_rewrites("theory::bags::rewrites") { - // TODO(projects#225): complete the code here - return RewriteResponse(REWRITE_DONE, node); + smtStatisticsRegistry()->registerStat(&d_rewrites); } -RewriteResponse TheoryBagsRewriter::preRewrite(TNode node) +BagsStatistics::~BagsStatistics() { - // TODO(projects#225): complete the code here - return RewriteResponse(REWRITE_DONE, node); + smtStatisticsRegistry()->unregisterStat(&d_rewrites); } } // namespace bags diff --git a/src/theory/bags/bags_statistics.h b/src/theory/bags/bags_statistics.h new file mode 100644 index 000000000..457e3a32e --- /dev/null +++ b/src/theory/bags/bags_statistics.h @@ -0,0 +1,45 @@ +/********************* */ +/*! \file bags_statistics.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Statistics for the theory of bags + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__BAGS_STATISTICS_H +#define CVC4__THEORY__BAGS_STATISTICS_H + +#include "expr/kind.h" +#include "theory/bags/rewrites.h" +#include "util/statistics_registry.h" + +namespace CVC4 { +namespace theory { +namespace bags { + +/** + * Statistics for the theory of bags. + */ +class BagsStatistics +{ + public: + BagsStatistics(); + ~BagsStatistics(); + + /** Counts the number of applications of each type of rewrite rule */ + HistogramStat d_rewrites; +}; + +} // namespace bags +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__BAGS_STATISTICS_H */ diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 8093448a0..cdbef58de 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -8,8 +8,8 @@ theory THEORY_BAGS \ ::CVC4::theory::bags::TheoryBags \ "theory/bags/theory_bags.h" typechecker "theory/bags/theory_bags_type_rules.h" -rewriter ::CVC4::theory::bags::TheoryBagsRewriter \ - "theory/bags/theory_bags_rewriter.h" +rewriter ::CVC4::theory::bags::BagsRewriter \ + "theory/bags/bags_rewriter.h" properties parametric properties check propagate presolve diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index d9248615b..facad3c92 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -22,6 +22,16 @@ bool NormalForm::checkNormalConstant(TNode n) return false; } +bool NormalForm::AreChildrenConstants(TNode n) +{ + return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); +} + +Node NormalForm::evaluate(TNode n) +{ + // TODO(projects#223): complete this function + return CVC4::Node(); +} } // namespace bags } // namespace theory } // namespace CVC4 \ No newline at end of file diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h index 73fd8dba8..8c719fe81 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/normal_form.h @@ -12,9 +12,10 @@ ** \brief Normal form for bag constants. **/ -#include "cvc4_private.h" #include +#include "cvc4_private.h" + #ifndef CVC4__THEORY__BAGS__NORMAL_FORM_H #define CVC4__THEORY__BAGS__NORMAL_FORM_H @@ -36,6 +37,14 @@ class NormalForm * Also handles the corner cases of empty bag and singleton bag. */ static bool checkNormalConstant(TNode n); + /** + * check whether all children of the given node are in normal form + */ + static bool AreChildrenConstants(TNode n); + /** + * evaluate the node n to a constant value + */ + static Node evaluate(TNode n); }; } // namespace bags } // namespace theory diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp new file mode 100644 index 000000000..758f8a6e6 --- /dev/null +++ b/src/theory/bags/rewrites.cpp @@ -0,0 +1,76 @@ +/********************* */ +/*! \file rewrites.cpp + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of inference information utility. + **/ + +#include "theory/bags/rewrites.h" + +#include + +namespace CVC4 { +namespace theory { +namespace bags { + +const char* toString(Rewrite r) +{ + switch (r) + { + case Rewrite::NONE: return "NONE"; + case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT"; + case Rewrite::CARD_MK_BAG: return "CARD_MK_BAG"; + case Rewrite::CHOOSE_MK_BAG: return "CHOOSE_MK_BAG"; + case Rewrite::CONSTANT_EVALUATION: return "CONSTANT_EVALUATION"; + case Rewrite::COUNT_EMPTY: return "COUNT_EMPTY"; + case Rewrite::COUNT_MK_BAG: return "COUNT_MK_BAG"; + case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES"; + case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT"; + case Rewrite::INTERSECTION_EMPTY_RIGHT: return "INTERSECTION_EMPTY_RIGHT"; + case Rewrite::INTERSECTION_SAME: return "INTERSECTION_SAME"; + case Rewrite::INTERSECTION_SHARED_LEFT: return "INTERSECTION_SHARED_LEFT"; + case Rewrite::INTERSECTION_SHARED_RIGHT: return "INTERSECTION_SHARED_RIGHT"; + case Rewrite::IS_SINGLETON_MK_BAG: return "IS_SINGLETON_MK_BAG"; + case Rewrite::MK_BAG_COUNT_NEGATIVE: return "MK_BAG_COUNT_NEGATIVE"; + case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION"; + case Rewrite::REMOVE_MIN: return "REMOVE_MIN"; + case Rewrite::REMOVE_RETURN_LEFT: return "REMOVE_RETURN_LEFT"; + case Rewrite::REMOVE_SAME: return "REMOVE_SAME"; + case Rewrite::SUB_BAG: return "SUB_BAG"; + case Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT: + return "SUBTRACT_DISJOINT_SHARED_LEFT"; + case Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT: + return "SUBTRACT_DISJOINT_SHARED_RIGHT"; + case Rewrite::SUBTRACT_FROM_UNION: return "SUBTRACT_FROM_UNION"; + case Rewrite::SUBTRACT_MIN: return "SUBTRACT_MIN"; + case Rewrite::SUBTRACT_RETURN_LEFT: return "SUBTRACT_RETURN_LEFT"; + case Rewrite::SUBTRACT_SAME: return "SUBTRACT_SAME"; + case Rewrite::UNION_DISJOINT_EMPTY_LEFT: return "UNION_DISJOINT_EMPTY_LEFT"; + case Rewrite::UNION_DISJOINT_EMPTY_RIGHT: + return "UNION_DISJOINT_EMPTY_RIGHT"; + case Rewrite::UNION_DISJOINT_MAX_MIN: return "UNION_DISJOINT_MAX_MIN"; + case Rewrite::UNION_MAX_EMPTY: return "UNION_MAX_EMPTY"; + case Rewrite::UNION_MAX_SAME_OR_EMPTY: return "UNION_MAX_SAME_OR_EMPTY"; + case Rewrite::UNION_MAX_UNION_LEFT: return "UNION_MAX_UNION_LEFT"; + case Rewrite::UNION_MAX_UNION_RIGHT: return "UNION_MAX_UNION_RIGHT"; + + default: return "?"; + } +} + +std::ostream& operator<<(std::ostream& out, Rewrite r) +{ + out << toString(r); + return out; +} + +} // namespace bags +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h new file mode 100644 index 000000000..13b0ff202 --- /dev/null +++ b/src/theory/bags/rewrites.h @@ -0,0 +1,91 @@ +/********************* */ +/*! \file rewrites.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Type for rewrites for bags. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__BAGS__REWRITES_H +#define CVC4__THEORY__BAGS__REWRITES_H + +#include + +namespace CVC4 { +namespace theory { +namespace bags { + +/** Types of rewrites used by bags + * + * This rewrites are documented where they are used in the rewriter. + */ +enum class Rewrite : uint32_t +{ + NONE, // no rewrite happened + CARD_DISJOINT, + CARD_MK_BAG, + CHOOSE_MK_BAG, + CONSTANT_EVALUATION, + COUNT_EMPTY, + COUNT_MK_BAG, + IDENTICAL_NODES, + INTERSECTION_EMPTY_LEFT, + INTERSECTION_EMPTY_RIGHT, + INTERSECTION_SAME, + INTERSECTION_SHARED_LEFT, + INTERSECTION_SHARED_RIGHT, + IS_SINGLETON_MK_BAG, + MK_BAG_COUNT_NEGATIVE, + REMOVE_FROM_UNION, + REMOVE_MIN, + REMOVE_RETURN_LEFT, + REMOVE_SAME, + SUB_BAG, + SUBTRACT_DISJOINT_SHARED_LEFT, + SUBTRACT_DISJOINT_SHARED_RIGHT, + SUBTRACT_FROM_UNION, + SUBTRACT_MIN, + SUBTRACT_RETURN_LEFT, + SUBTRACT_SAME, + UNION_DISJOINT_EMPTY_LEFT, + UNION_DISJOINT_EMPTY_RIGHT, + UNION_DISJOINT_MAX_MIN, + UNION_MAX_EMPTY, + UNION_MAX_SAME_OR_EMPTY, + UNION_MAX_UNION_LEFT, + UNION_MAX_UNION_RIGHT +}; + +/** + * Converts an rewrite to a string. Note: This function is also used in + * `safe_print()`. Changing this functions name or signature will result in + * `safe_print()` printing "" instead of the proper strings for + * the enum values. + * + * @param r The rewrite + * @return The name of the rewrite + */ +const char* toString(Rewrite r); + +/** + * Writes an rewrite name to a stream. + * + * @param out The stream to write to + * @param r The rewrite to write to the stream + * @return The stream + */ +std::ostream& operator<<(std::ostream& out, Rewrite r); + +} // namespace bags +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__BAGS__REWRITES_H */ diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 5ddd17302..e4cd64b48 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -29,8 +29,9 @@ TheoryBags::TheoryBags(context::Context* c, : Theory(THEORY_BAGS, c, u, out, valuation, logicInfo, pnm), d_state(c, u, valuation), d_im(*this, d_state, pnm), - d_rewriter(), - d_notify(*this, d_im) + d_notify(*this, d_im), + d_statistics(), + d_rewriter(&d_statistics.d_rewrites) { // use the official theory state and inference manager objects d_theoryState = &d_state; diff --git a/src/theory/bags/theory_bags.h b/src/theory/bags/theory_bags.h index 44f7ae1b0..08bc5f33a 100644 --- a/src/theory/bags/theory_bags.h +++ b/src/theory/bags/theory_bags.h @@ -19,9 +19,10 @@ #include +#include "theory/bags/bags_rewriter.h" +#include "theory/bags/bags_statistics.h" #include "theory/bags/inference_manager.h" #include "theory/bags/solver_state.h" -#include "theory/bags/theory_bags_rewriter.h" #include "theory/theory.h" #include "theory/theory_eq_notify.h" #include "theory/uf/equality_engine.h" @@ -95,8 +96,10 @@ class TheoryBags : public Theory InferenceManager d_im; /** Instance of the above class */ NotifyClass d_notify; + /** Statistics for the theory of bags. */ + BagsStatistics d_statistics; /** The theory rewriter for this theory. */ - TheoryBagsRewriter d_rewriter; + BagsRewriter d_rewriter; void eqNotifyNewClass(TNode t); void eqNotifyMerge(TNode t1, TNode t2); diff --git a/src/theory/bags/theory_bags_rewriter.h b/src/theory/bags/theory_bags_rewriter.h deleted file mode 100644 index 7be88636a..000000000 --- a/src/theory/bags/theory_bags_rewriter.h +++ /dev/null @@ -1,38 +0,0 @@ -/********************* */ -/*! \file theory_bags_rewriter.h - ** \verbatim - ** Top contributors (to current version): - ** Mudathir Mohamed - ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS - ** in the top-level source directory) and their institutional affiliations. - ** All rights reserved. See the file COPYING in the top-level source - ** directory for licensing information.\endverbatim - ** - ** \brief Bags theory rewriter. - **/ - -#include "cvc4_private.h" - -#ifndef CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H -#define CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H - -#include "theory/rewriter.h" - -namespace CVC4 { -namespace theory { -namespace bags { - -class TheoryBagsRewriter : public TheoryRewriter -{ - public: - RewriteResponse postRewrite(TNode node) override; - - RewriteResponse preRewrite(TNode node) override; -}; /* class TheoryBagsRewriter */ - -} // namespace bags -} // namespace theory -} // namespace CVC4 - -#endif /* CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H */ diff --git a/src/theory/bags/theory_bags_type_enumerator.h b/src/theory/bags/theory_bags_type_enumerator.h index 26639afd8..a1ba896c1 100644 --- a/src/theory/bags/theory_bags_type_enumerator.h +++ b/src/theory/bags/theory_bags_type_enumerator.h @@ -66,7 +66,7 @@ class BagEnumerator : public TypeEnumeratorBase * * This seems too expensive to implement. * For now we are implementing an obvious solution - * {(1,1)}, {(1,2)}, {(1,3)}, ... which works for both fininte and infinite + * {(1,1)}, {(1,2)}, {(1,3)}, ... which works for both finite and infinite * types */ BagEnumerator& operator++() override; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index fc5a19348..e4279479d 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -236,7 +236,7 @@ struct BagsProperties { static Cardinality computeCardinality(TypeNode type) { - return Cardinality::UNKNOWN_CARD; + return Cardinality::INTEGERS; } static bool isWellFounded(TypeNode type) { return type[0].isWellFounded(); } diff --git a/test/unit/theory/CMakeLists.txt b/test/unit/theory/CMakeLists.txt index f40d9658b..108471d4a 100644 --- a/test/unit/theory/CMakeLists.txt +++ b/test/unit/theory/CMakeLists.txt @@ -14,6 +14,7 @@ cvc4_add_unit_test_white(evaluator_white theory) cvc4_add_unit_test_white(logic_info_white theory) cvc4_add_unit_test_white(sequences_rewriter_white theory) cvc4_add_unit_test_white(theory_arith_white theory) +cvc4_add_unit_test_white(theory_bags_rewriter_black theory) cvc4_add_unit_test_white(theory_bags_type_rules_black theory) cvc4_add_unit_test_white(theory_bv_rewriter_white theory) cvc4_add_unit_test_white(theory_bv_white theory) diff --git a/test/unit/theory/theory_bags_rewriter_black.h b/test/unit/theory/theory_bags_rewriter_black.h new file mode 100644 index 000000000..d51805854 --- /dev/null +++ b/test/unit/theory/theory_bags_rewriter_black.h @@ -0,0 +1,593 @@ +/********************* */ +/*! \file theory_bags_rewriter_black.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Black box testing of bags rewriter + **/ + +#include + +#include "expr/dtype.h" +#include "smt/smt_engine.h" +#include "theory/bags/bags_rewriter.h" +#include "theory/strings/type_enumerator.h" + +using namespace CVC4; +using namespace CVC4::smt; +using namespace CVC4::theory; +using namespace CVC4::kind; +using namespace CVC4::theory::bags; +using namespace std; + +typedef expr::Attribute attribute; + +class BagsTypeRuleBlack : public CxxTest::TestSuite +{ + public: + void setUp() override + { + d_em.reset(new ExprManager()); + d_smt.reset(new SmtEngine(d_em.get())); + d_nm.reset(NodeManager::fromExprManager(d_em.get())); + d_smt->finishInit(); + d_rewriter.reset(new BagsRewriter(nullptr)); + } + + void tearDown() override + { + d_rewriter.reset(); + d_smt.reset(); + d_nm.release(); + d_em.reset(); + } + + std::vector getNStrings(size_t n) + { + std::vector elements(n); + for (size_t i = 0; i < n; i++) + { + elements[i] = d_nm->mkSkolem("x", d_nm->stringType()); + } + return elements; + } + + void testEmptyBagNormalForm() + { + Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType())); + // empty bags are in normal form + TS_ASSERT(emptybag.isConst()); + RewriteResponse response = d_rewriter->postRewrite(emptybag); + TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE); + } + + void testBagEquality() + { + vector elements = getNStrings(2); + Node x = elements[0]; + Node y = elements[1]; + Node c = d_nm->mkSkolem("c", d_nm->integerType()); + Node d = d_nm->mkSkolem("d", d_nm->integerType()); + Node bagX = d_nm->mkNode(MK_BAG, x, c); + Node bagY = d_nm->mkNode(MK_BAG, y, d); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + + // (= A A) = true where A is a bag + Node n1 = emptyBag.eqNode(emptyBag); + RewriteResponse response1 = d_rewriter->preRewrite(n1); + TS_ASSERT(response1.d_node == d_nm->mkConst(true) + && response1.d_status == REWRITE_AGAIN_FULL); + } + + void testMkBagConstantElement() + { + vector elements = getNStrings(1); + Node negative = + d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1))); + Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0))); + Node positive = + d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1))); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + RewriteResponse negativeResponse = d_rewriter->postRewrite(negative); + RewriteResponse zeroResponse = d_rewriter->postRewrite(zero); + RewriteResponse positiveResponse = d_rewriter->postRewrite(positive); + + // bags with non-positive multiplicity are rewritten as empty bags + TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL + && negativeResponse.d_node == emptybag); + TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL + && zeroResponse.d_node == emptybag); + + // no change for positive + TS_ASSERT(positiveResponse.d_status == REWRITE_DONE + && positive == positiveResponse.d_node); + } + + void testMkBagVariableElement() + { + Node skolem = d_nm->mkSkolem("x", d_nm->stringType()); + Node variable = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1))); + Node negative = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1))); + Node zero = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(0))); + Node positive = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(1))); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + RewriteResponse negativeResponse = d_rewriter->postRewrite(negative); + RewriteResponse zeroResponse = d_rewriter->postRewrite(zero); + RewriteResponse positiveResponse = d_rewriter->postRewrite(positive); + + // bags with non-positive multiplicity are rewritten as empty bags + TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL + && negativeResponse.d_node == emptybag); + TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL + && zeroResponse.d_node == emptybag); + + // no change for positive + TS_ASSERT(positiveResponse.d_status == REWRITE_DONE + && positive == positiveResponse.d_node); + } + + void testBagCount() + { + int n = 3; + Node skolem = d_nm->mkSkolem("x", d_nm->stringType()); + Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType()))); + Node bag = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(n))); + + // (bag.count x emptybag) = 0 + Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL + && response1.d_node == d_nm->mkConst(Rational(0))); + + // (bag.count x (mkBag x c) = c where c > 0 is a constant + Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL + && response2.d_node == d_nm->mkConst(Rational(n))); + } + + void testUnionMax() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + + // (union_max A emptybag) = A + Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(unionMax1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (union_max emptybag A) = A + Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(unionMax2); + TS_ASSERT(response2.d_node == A + && response2.d_status == REWRITE_AGAIN_FULL); + + // (union_max A A) = A + Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(unionMax3); + TS_ASSERT(response3.d_node == A + && response3.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_max A B)) = (union_max A B) + Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB); + RewriteResponse response4 = d_rewriter->postRewrite(unionMax4); + TS_ASSERT(response4.d_node == unionMaxAB + && response4.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_max B A)) = (union_max B A) + Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA); + RewriteResponse response5 = d_rewriter->postRewrite(unionMax5); + TS_ASSERT(response5.d_node == unionMaxBA + && response4.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_max A B) A) = (union_max A B) + Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A); + RewriteResponse response6 = d_rewriter->postRewrite(unionMax6); + TS_ASSERT(response6.d_node == unionMaxAB + && response6.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_max B A) A) = (union_max B A) + Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A); + RewriteResponse response7 = d_rewriter->postRewrite(unionMax7); + TS_ASSERT(response7.d_node == unionMaxBA + && response7.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_disjoint A B)) = (union_disjoint A B) + Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB); + RewriteResponse response8 = d_rewriter->postRewrite(unionMax8); + TS_ASSERT(response8.d_node == unionDisjointAB + && response8.d_status == REWRITE_AGAIN_FULL); + + // (union_max A (union_disjoint B A)) = (union_disjoint B A) + Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA); + RewriteResponse response9 = d_rewriter->postRewrite(unionMax9); + TS_ASSERT(response9.d_node == unionDisjointBA + && response9.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_disjoint A B) A) = (union_disjoint A B) + Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(unionMax10); + TS_ASSERT(response10.d_node == unionDisjointAB + && response10.d_status == REWRITE_AGAIN_FULL); + + // (union_max (union_disjoint B A) A) = (union_disjoint B A) + Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(unionMax11); + TS_ASSERT(response11.d_node == unionDisjointBA + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testUnionDisjoint() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); + Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); + + // (union_disjoint A emptybag) = A + Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint emptybag A) = A + Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2); + TS_ASSERT(response2.d_node == A + && response2.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint (union_max A B) (intersection_min B A)) = + // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b) + Node unionDisjoint3 = + d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA); + RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3); + TS_ASSERT(response3.d_node == unionDisjointAB + && response3.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint (intersection_min B A)) (union_max A B) = + // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b) + Node unionDisjoint4 = + d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA); + RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4); + TS_ASSERT(response4.d_node == unionDisjointBA + && response4.d_status == REWRITE_AGAIN_FULL); + } + + void testIntersectionMin() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + + // (intersection_min A emptybag) = emptyBag + Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == emptyBag + && response1.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min emptybag A) = emptyBag + Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == emptyBag + && response2.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A A) = A + Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == A + && response3.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_max A B) = A + Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB); + RewriteResponse response4 = d_rewriter->postRewrite(n4); + TS_ASSERT(response4.d_node == A + && response4.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_max B A) = A + Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA); + RewriteResponse response5 = d_rewriter->postRewrite(n5); + TS_ASSERT(response5.d_node == A + && response4.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_max A B) A) = A + Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A); + RewriteResponse response6 = d_rewriter->postRewrite(n6); + TS_ASSERT(response6.d_node == A + && response6.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_max B A) A) = A + Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A); + RewriteResponse response7 = d_rewriter->postRewrite(n7); + TS_ASSERT(response7.d_node == A + && response7.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_disjoint A B) = A + Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB); + RewriteResponse response8 = d_rewriter->postRewrite(n8); + TS_ASSERT(response8.d_node == A + && response8.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min A (union_disjoint B A) = A + Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA); + RewriteResponse response9 = d_rewriter->postRewrite(n9); + TS_ASSERT(response9.d_node == A + && response9.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_disjoint A B) A) = A + Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(n10); + TS_ASSERT(response10.d_node == A + && response10.d_status == REWRITE_AGAIN_FULL); + + // (intersection_min (union_disjoint B A) A) = A + Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(n11); + TS_ASSERT(response11.d_node == A + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testDifferenceSubtract() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); + Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); + + // (difference_subtract A emptybag) = A + Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract emptybag A) = emptyBag + Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == emptyBag + && response2.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A A) = emptybag + Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == emptyBag + && response3.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (union_disjoint A B) A) = B + Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A); + RewriteResponse response4 = d_rewriter->postRewrite(n4); + TS_ASSERT(response4.d_node == B + && response4.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (union_disjoint B A) A) = B + Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A); + RewriteResponse response5 = d_rewriter->postRewrite(n5); + TS_ASSERT(response5.d_node == B + && response4.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_disjoint A B)) = emptybag + Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB); + RewriteResponse response6 = d_rewriter->postRewrite(n6); + TS_ASSERT(response6.d_node == emptyBag + && response6.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_disjoint B A)) = emptybag + Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA); + RewriteResponse response7 = d_rewriter->postRewrite(n7); + TS_ASSERT(response7.d_node == emptyBag + && response7.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_max A B)) = emptybag + Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB); + RewriteResponse response8 = d_rewriter->postRewrite(n8); + TS_ASSERT(response8.d_node == emptyBag + && response8.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract A (union_max B A)) = emptybag + Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA); + RewriteResponse response9 = d_rewriter->postRewrite(n9); + TS_ASSERT(response9.d_node == emptyBag + && response9.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (intersection_min A B) A) = emptybag + Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(n10); + TS_ASSERT(response10.d_node == emptyBag + && response10.d_status == REWRITE_AGAIN_FULL); + + // (difference_subtract (intersection_min B A) A) = emptybag + Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(n11); + TS_ASSERT(response11.d_node == emptyBag + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testDifferenceRemove() + { + int n = 3; + vector elements = getNStrings(2); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n))); + Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1))); + Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); + Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); + + // (difference_remove A emptybag) = A + Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == A + && response1.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove emptybag A) = emptyBag + Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == emptyBag + && response2.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A A) = emptybag + Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == emptyBag + && response3.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_disjoint A B)) = emptybag + Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB); + RewriteResponse response6 = d_rewriter->postRewrite(n6); + TS_ASSERT(response6.d_node == emptyBag + && response6.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_disjoint B A)) = emptybag + Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA); + RewriteResponse response7 = d_rewriter->postRewrite(n7); + TS_ASSERT(response7.d_node == emptyBag + && response7.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_max A B)) = emptybag + Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB); + RewriteResponse response8 = d_rewriter->postRewrite(n8); + TS_ASSERT(response8.d_node == emptyBag + && response8.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove A (union_max B A)) = emptybag + Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA); + RewriteResponse response9 = d_rewriter->postRewrite(n9); + TS_ASSERT(response9.d_node == emptyBag + && response9.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove (intersection_min A B) A) = emptybag + Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A); + RewriteResponse response10 = d_rewriter->postRewrite(n10); + TS_ASSERT(response10.d_node == emptyBag + && response10.d_status == REWRITE_AGAIN_FULL); + + // (difference_remove (intersection_min B A) A) = emptybag + Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A); + RewriteResponse response11 = d_rewriter->postRewrite(n11); + TS_ASSERT(response11.d_node == emptyBag + && response11.d_status == REWRITE_AGAIN_FULL); + } + + void testChoose() + { + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node c = d_nm->mkConst(Rational(3)); + Node bag = d_nm->mkNode(MK_BAG, x, c); + + // (bag.choose (mkBag x c)) = x where c is a constant > 0 + Node n1 = d_nm->mkNode(BAG_CHOOSE, bag); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TS_ASSERT(response1.d_node == x + && response1.d_status == REWRITE_AGAIN_FULL); + } + + void testBagCard() + { + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node emptyBag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node zero = d_nm->mkConst(Rational(0)); + Node c = d_nm->mkConst(Rational(3)); + Node bag = d_nm->mkNode(MK_BAG, x, c); + vector elements = getNStrings(2); + Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(4))); + Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(5))); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + + // TODO(projects#223): enable this test after implementing bags normal form + // // (bag.card emptybag) = 0 + // Node n1 = d_nm->mkNode(BAG_CARD, emptyBag); + // RewriteResponse response1 = d_rewriter->postRewrite(n1); + // TS_ASSERT(response1.d_node == zero && response1.d_status == + // REWRITE_AGAIN_FULL); + + // (bag.card (mkBag x c)) = c where c is a constant > 0 + Node n2 = d_nm->mkNode(BAG_CARD, bag); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + TS_ASSERT(response2.d_node == c + && response2.d_status == REWRITE_AGAIN_FULL); + + // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B)) + Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB); + Node cardA = d_nm->mkNode(BAG_CARD, A); + Node cardB = d_nm->mkNode(BAG_CARD, B); + Node plus = d_nm->mkNode(PLUS, cardA, cardB); + RewriteResponse response3 = d_rewriter->postRewrite(n3); + TS_ASSERT(response3.d_node == plus + && response3.d_status == REWRITE_AGAIN_FULL); + } + + void testIsSingleton() + { + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node x = d_nm->mkSkolem("x", d_nm->stringType()); + Node c = d_nm->mkSkolem("c", d_nm->integerType()); + Node bag = d_nm->mkNode(MK_BAG, x, c); + + // TODO(projects#223): complete this function + // (bag.is_singleton emptybag) = false + // Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag); + // RewriteResponse response1 = d_rewriter->postRewrite(n1); + // TS_ASSERT(response1.d_node == d_nm->mkConst(false) + // && response1.d_status == REWRITE_AGAIN_FULL); + + // (bag.is_singleton (mkBag x c) = (c == 1) + Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag); + RewriteResponse response2 = d_rewriter->postRewrite(n2); + Node one = d_nm->mkConst(Rational(1)); + Node equal = c.eqNode(one); + TS_ASSERT(response2.d_node == equal + && response2.d_status == REWRITE_AGAIN_FULL); + } + + private: + std::unique_ptr d_em; + std::unique_ptr d_smt; + std::unique_ptr d_nm; + + std::unique_ptr d_rewriter; +}; /* class BagsTypeRuleBlack */ -- 2.30.2