From 70f007e3fbf76d47aa52b71a10f24a189311c945 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Tue, 4 Jan 2022 10:49:00 -0600 Subject: [PATCH] Add bag.member operator to theory of bags (#7857) This PR adds the predicate bag.member to be analogous to predicate set.member. The PR is motivated by converting regressions for sets to bags, which avoids defining a predicate for each set type (define-fun bag.member ((e E) (B (Bag E))) Bool (>= (bag.count e B) 1)) --- src/api/cpp/cvc5.cpp | 2 ++ src/api/cpp/cvc5_kind.h | 11 ++++++++ src/parser/smt2/smt2.cpp | 1 + src/printer/smt2/smt2_printer.cpp | 1 + src/theory/bags/bags_rewriter.cpp | 11 ++++++++ src/theory/bags/bags_rewriter.h | 8 +++++- src/theory/bags/kinds | 2 ++ src/theory/bags/rewrites.cpp | 3 ++- src/theory/bags/rewrites.h | 3 ++- src/theory/bags/theory_bags_type_rules.cpp | 29 ++++++++++++++++++++++ src/theory/bags/theory_bags_type_rules.h | 9 +++++++ test/regress/CMakeLists.txt | 1 + test/regress/regress1/bags/bag_member.smt2 | 5 ++++ 13 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 test/regress/regress1/bags/bag_member.smt2 diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index eaea15b76..e794606aa 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -304,6 +304,7 @@ const static std::unordered_map s_kinds{ {BAG_DIFFERENCE_REMOVE, cvc5::Kind::BAG_DIFFERENCE_REMOVE}, {BAG_SUBBAG, cvc5::Kind::BAG_SUBBAG}, {BAG_COUNT, cvc5::Kind::BAG_COUNT}, + {BAG_MEMBER, cvc5::Kind::BAG_MEMBER}, {BAG_DUPLICATE_REMOVAL, cvc5::Kind::BAG_DUPLICATE_REMOVAL}, {BAG_MAKE, cvc5::Kind::BAG_MAKE}, {BAG_EMPTY, cvc5::Kind::BAG_EMPTY}, @@ -616,6 +617,7 @@ const static std::unordered_map {cvc5::Kind::BAG_DIFFERENCE_REMOVE, BAG_DIFFERENCE_REMOVE}, {cvc5::Kind::BAG_SUBBAG, BAG_SUBBAG}, {cvc5::Kind::BAG_COUNT, BAG_COUNT}, + {cvc5::Kind::BAG_MEMBER, BAG_MEMBER}, {cvc5::Kind::BAG_DUPLICATE_REMOVAL, BAG_DUPLICATE_REMOVAL}, {cvc5::Kind::BAG_MAKE, BAG_MAKE}, {cvc5::Kind::BAG_EMPTY, BAG_EMPTY}, diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index e465a8faa..9c885cb7b 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -2446,6 +2446,17 @@ enum Kind : int32_t * - `Solver::mkTerm(Kind kind, const std::vector& children) const` */ BAG_COUNT, + /** + * Bag membership predicate. + * + * Parameters: + * - 1..2: Terms of bag sort (Bag E), is [1] of type E an element of [2] + * + * Create with: + * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2) const` + * - `Solver::mkTerm(Kind kind, const std::vector& children) const` + */ + BAG_MEMBER, /** * Eliminate duplicates in a given bag. The returned bag contains exactly the * same elements in the given bag, but with multiplicity one. diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 1fca42634..cf2db0179 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -621,6 +621,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(api::BAG_DIFFERENCE_REMOVE, "bag.difference_remove"); addOperator(api::BAG_SUBBAG, "bag.subbag"); addOperator(api::BAG_COUNT, "bag.count"); + addOperator(api::BAG_MEMBER, "bag.member"); addOperator(api::BAG_DUPLICATE_REMOVAL, "bag.duplicate_removal"); addOperator(api::BAG_MAKE, "bag"); addOperator(api::BAG_CARD, "bag.card"); diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 69da5d03d..08c0482da 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1090,6 +1090,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_DIFFERENCE_REMOVE: return "bag.difference_remove"; case kind::BAG_SUBBAG: return "bag.subbag"; case kind::BAG_COUNT: return "bag.count"; + case kind::BAG_MEMBER: return "bag.member"; case kind::BAG_DUPLICATE_REMOVAL: return "bag.duplicate_removal"; case kind::BAG_MAKE: return "bag"; case kind::BAG_CARD: return "bag.card"; diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index f193bf73c..40f8d6c95 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -117,6 +117,7 @@ RewriteResponse BagsRewriter::preRewrite(TNode n) { case EQUAL: response = preRewriteEqual(n); break; case BAG_SUBBAG: response = rewriteSubBag(n); break; + case BAG_MEMBER: response = rewriteMember(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); } @@ -156,6 +157,16 @@ BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const return BagsRewriteResponse(equal, Rewrite::SUB_BAG); } +BagsRewriteResponse BagsRewriter::rewriteMember(const TNode& n) const +{ + Assert(n.getKind() == BAG_MEMBER); + + // - (bag.member x A) = (>= (bag.count x A) 1) + Node count = d_nm->mkNode(BAG_COUNT, n[0], n[1]); + Node geq = d_nm->mkNode(GEQ, count, d_one); + return BagsRewriteResponse(geq, Rewrite::MEMBER); +} + BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const { Assert(n.getKind() == BAG_MAKE); diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index d666982a7..b4b1e9043 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -52,7 +52,7 @@ class BagsRewriter : public TheoryRewriter */ RewriteResponse postRewrite(TNode n) override; /** - * preRewrite nodes with kinds: EQUAL, BGA_SUBBAG. + * preRewrite nodes with kinds: EQUAL, BAG_SUBBAG, BAG_MEMBER. * See the rewrite rules for these kinds below. */ RewriteResponse preRewrite(TNode n) override; @@ -70,6 +70,12 @@ class BagsRewriter : public TheoryRewriter */ BagsRewriteResponse rewriteSubBag(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.member x A) = (>= (bag.count x A) 1) + */ + BagsRewriteResponse rewriteMember(const TNode& n) const; + /** * rewrites for n include: * - (bag x 0) = (bag.empty T) where T is the type of x diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 5e4119fa1..d83be5e21 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -48,6 +48,7 @@ operator BAG_DIFFERENCE_REMOVE 2 "bag difference remove (removes shared element operator BAG_SUBBAG 2 "inclusion predicate for bags (less than or equal multiplicities)" operator BAG_COUNT 2 "multiplicity of an element in a bag" +operator BAG_MEMBER 2 "bag membership predicate; is first parameter a member of second?" operator BAG_DUPLICATE_REMOVAL 1 "eliminate duplicates in a bag (also known as the delta operator,or the squash operator)" constant BAG_MAKE_OP \ @@ -91,6 +92,7 @@ typerule BAG_DIFFERENCE_SUBTRACT ::cvc5::theory::bags::BinaryOperatorTypeRule typerule BAG_DIFFERENCE_REMOVE ::cvc5::theory::bags::BinaryOperatorTypeRule typerule BAG_SUBBAG ::cvc5::theory::bags::SubBagTypeRule typerule BAG_COUNT ::cvc5::theory::bags::CountTypeRule +typerule BAG_MEMBER ::cvc5::theory::bags::MemberTypeRule typerule BAG_DUPLICATE_REMOVAL ::cvc5::theory::bags::DuplicateRemovalTypeRule typerule BAG_MAKE_OP "SimpleTypeRule" typerule BAG_MAKE ::cvc5::theory::bags::BagMakeTypeRule diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index 1a8f8f849..d8ed9fb95 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -26,6 +26,7 @@ const char* toString(Rewrite r) switch (r) { case Rewrite::NONE: return "NONE"; + case Rewrite::BAG_MAKE_COUNT_NEGATIVE: return "BAG_MAKE_COUNT_NEGATIVE"; case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT"; case Rewrite::CARD_BAG_MAKE: return "CARD_BAG_MAKE"; case Rewrite::CHOOSE_BAG_MAKE: return "CHOOSE_BAG_MAKE"; @@ -51,7 +52,7 @@ const char* toString(Rewrite r) case Rewrite::MAP_CONST: return "MAP_CONST"; case Rewrite::MAP_BAG_MAKE: return "MAP_BAG_MAKE"; case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT"; - case Rewrite::BAG_MAKE_COUNT_NEGATIVE: return "BAG_MAKE_COUNT_NEGATIVE"; + case Rewrite::MEMBER: return "MEMBER"; case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION"; case Rewrite::REMOVE_MIN: return "REMOVE_MIN"; case Rewrite::REMOVE_RETURN_LEFT: return "REMOVE_RETURN_LEFT"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index 0b7188599..57f106211 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -31,6 +31,7 @@ namespace bags { enum class Rewrite : uint32_t { NONE, // no rewrite happened + BAG_MAKE_COUNT_NEGATIVE, CARD_DISJOINT, CARD_BAG_MAKE, CHOOSE_BAG_MAKE, @@ -55,7 +56,7 @@ enum class Rewrite : uint32_t MAP_CONST, MAP_BAG_MAKE, MAP_UNION_DISJOINT, - BAG_MAKE_COUNT_NEGATIVE, + MEMBER, REMOVE_FROM_UNION, REMOVE_MIN, REMOVE_RETURN_LEFT, diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index fe81fadf5..2d218f821 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -120,6 +120,35 @@ TypeNode CountTypeRule::computeType(NodeManager* nodeManager, return nodeManager->integerType(); } +TypeNode MemberTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::BAG_MEMBER); + TypeNode bagType = n[1].getType(check); + if (check) + { + if (!bagType.isBag()) + { + throw TypeCheckingExceptionPrivate( + n, "checking for membership in a non-bag"); + } + TypeNode elementType = n[0].getType(check); + // e.g. (bag.member 1 (bag 1.0 1)) is true whereas + // (bag.member 1.0 (bag 1 1)) throws a typing error + if (!elementType.isSubtypeOf(bagType.getBagElementType())) + { + std::stringstream ss; + ss << "member operating on bags of different types:\n" + << "child type: " << elementType << "\n" + << "not subtype: " << bagType.getBagElementType() << "\n" + << "in term : " << n; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + return nodeManager->booleanType(); +} + TypeNode DuplicateRemovalTypeRule::computeType(NodeManager* nodeManager, TNode n, bool check) diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index fa2f78313..da9ea75bf 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -57,6 +57,15 @@ struct CountTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct CountTypeRule */ +/** + * Type rule for binary operator bag.member to check the sort of the first + * argument matches the element sort of the given bag. + */ +struct MemberTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; + /** * Type rule for bag.duplicate_removal to check the argument is of a bag. */ diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 3c0e79596..6009c020d 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1611,6 +1611,7 @@ set(regress_1_tests regress1/bug681.smt2 regress1/bug694-Unapply1.scala-0.smt2 regress1/bug800.smt2 + regress1/bags/bag_member.smt2 regress1/bags/bags-of-bags-subtypes.smt2 regress1/bags/card1.smt2 regress1/bags/card2.smt2 diff --git a/test/regress/regress1/bags/bag_member.smt2 b/test/regress/regress1/bags/bag_member.smt2 new file mode 100644 index 000000000..a2275caea --- /dev/null +++ b/test/regress/regress1/bags/bag_member.smt2 @@ -0,0 +1,5 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun B () (Bag String)) +(assert (bag.member "x" B)) +(check-sat) -- 2.30.2