Add bag.member operator to theory of bags (#7857)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 4 Jan 2022 16:49:00 +0000 (10:49 -0600)
committerGitHub <noreply@github.com>
Tue, 4 Jan 2022 16:49:00 +0000 (16:49 +0000)
This PR adds the predicate bag.member to be analogous to predicate set.member.
The PR is motivated by converting regressions for sets to bags, which avoids defining a predicate for each set type

(define-fun bag.member ((e E) (B (Bag E))) Bool (>= (bag.count e B) 1))

13 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/kinds
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
test/regress/CMakeLists.txt
test/regress/regress1/bags/bag_member.smt2 [new file with mode: 0644]

index eaea15b765200539ba2c35f19468b8b5234f5f3a..e794606aa450209f70da054be90a3e79bc370503 100644 (file)
@@ -304,6 +304,7 @@ const static std::unordered_map<Kind, cvc5::Kind> s_kinds{
     {BAG_DIFFERENCE_REMOVE, cvc5::Kind::BAG_DIFFERENCE_REMOVE},
     {BAG_SUBBAG, cvc5::Kind::BAG_SUBBAG},
     {BAG_COUNT, cvc5::Kind::BAG_COUNT},
+    {BAG_MEMBER, cvc5::Kind::BAG_MEMBER},
     {BAG_DUPLICATE_REMOVAL, cvc5::Kind::BAG_DUPLICATE_REMOVAL},
     {BAG_MAKE, cvc5::Kind::BAG_MAKE},
     {BAG_EMPTY, cvc5::Kind::BAG_EMPTY},
@@ -616,6 +617,7 @@ const static std::unordered_map<cvc5::Kind, Kind, cvc5::kind::KindHashFunction>
         {cvc5::Kind::BAG_DIFFERENCE_REMOVE, BAG_DIFFERENCE_REMOVE},
         {cvc5::Kind::BAG_SUBBAG, BAG_SUBBAG},
         {cvc5::Kind::BAG_COUNT, BAG_COUNT},
+        {cvc5::Kind::BAG_MEMBER, BAG_MEMBER},
         {cvc5::Kind::BAG_DUPLICATE_REMOVAL, BAG_DUPLICATE_REMOVAL},
         {cvc5::Kind::BAG_MAKE, BAG_MAKE},
         {cvc5::Kind::BAG_EMPTY, BAG_EMPTY},
index e465a8faaba40549b648a1c0d0ac8f58b7e1c58d..9c885cb7be45f682cd08ce03ff26d4c009c5ee24 100644 (file)
@@ -2446,6 +2446,17 @@ enum Kind : int32_t
    *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
    */
   BAG_COUNT,
+  /**
+   * Bag membership predicate.
+   *
+   * Parameters:
+   *   - 1..2: Terms of bag sort (Bag E), is [1] of type E an element of [2]
+   *
+   * Create with:
+   *   - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2) const`
+   *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
+   */
+  BAG_MEMBER,
   /**
    * Eliminate duplicates in a given bag. The returned bag contains exactly the
    * same elements in the given bag, but with multiplicity one.
index 1fca42634535198f4cdd17cf2e5af50bc65bf67b..cf2db0179a6514ec496ee8dcaebe61e1d56bd20e 100644 (file)
@@ -621,6 +621,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::BAG_DIFFERENCE_REMOVE, "bag.difference_remove");
     addOperator(api::BAG_SUBBAG, "bag.subbag");
     addOperator(api::BAG_COUNT, "bag.count");
+    addOperator(api::BAG_MEMBER, "bag.member");
     addOperator(api::BAG_DUPLICATE_REMOVAL, "bag.duplicate_removal");
     addOperator(api::BAG_MAKE, "bag");
     addOperator(api::BAG_CARD, "bag.card");
index 69da5d03d85b4b545581a78ed449bbc40832389d..08c0482dafa290bb346c817f888293f93df5eedb 100644 (file)
@@ -1090,6 +1090,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_DIFFERENCE_REMOVE: return "bag.difference_remove";
   case kind::BAG_SUBBAG: return "bag.subbag";
   case kind::BAG_COUNT: return "bag.count";
+  case kind::BAG_MEMBER: return "bag.member";
   case kind::BAG_DUPLICATE_REMOVAL: return "bag.duplicate_removal";
   case kind::BAG_MAKE: return "bag";
   case kind::BAG_CARD: return "bag.card";
index f193bf73cdb6f576a0e5f7a4693e3255b766e904..40f8d6c95e9adfa4ba03b25b088c8a8769c2f6c7 100644 (file)
@@ -117,6 +117,7 @@ RewriteResponse BagsRewriter::preRewrite(TNode n)
   {
     case EQUAL: response = preRewriteEqual(n); break;
     case BAG_SUBBAG: response = rewriteSubBag(n); break;
+    case BAG_MEMBER: response = rewriteMember(n); break;
     default: response = BagsRewriteResponse(n, Rewrite::NONE);
   }
 
@@ -156,6 +157,16 @@ BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const
   return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
 }
 
+BagsRewriteResponse BagsRewriter::rewriteMember(const TNode& n) const
+{
+  Assert(n.getKind() == BAG_MEMBER);
+
+  // - (bag.member x A) = (>= (bag.count x A) 1)
+  Node count = d_nm->mkNode(BAG_COUNT, n[0], n[1]);
+  Node geq = d_nm->mkNode(GEQ, count, d_one);
+  return BagsRewriteResponse(geq, Rewrite::MEMBER);
+}
+
 BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
 {
   Assert(n.getKind() == BAG_MAKE);
index d666982a7d08c967112e375adb101b1801d1e75b..b4b1e90435db9cd7e0337594a530c385abef0b2c 100644 (file)
@@ -52,7 +52,7 @@ class BagsRewriter : public TheoryRewriter
    */
   RewriteResponse postRewrite(TNode n) override;
   /**
-   * preRewrite nodes with kinds: EQUAL, BGA_SUBBAG.
+   * preRewrite nodes with kinds: EQUAL, BAG_SUBBAG, BAG_MEMBER.
    * See the rewrite rules for these kinds below.
    */
   RewriteResponse preRewrite(TNode n) override;
@@ -70,6 +70,12 @@ class BagsRewriter : public TheoryRewriter
    */
   BagsRewriteResponse rewriteSubBag(const TNode& n) const;
 
+  /**
+   * rewrites for n include:
+   * - (bag.member x A) = (>= (bag.count x A) 1)
+   */
+  BagsRewriteResponse rewriteMember(const TNode& n) const;
+
   /**
    * rewrites for n include:
    * - (bag x 0) = (bag.empty T) where T is the type of x
index 5e4119fa19b05ddb257d3d844859cb23fc68277b..d83be5e211b913ea28d544b1dc4b79b77bf37654 100644 (file)
@@ -48,6 +48,7 @@ operator BAG_DIFFERENCE_REMOVE 2  "bag difference remove (removes shared element
 
 operator BAG_SUBBAG            2  "inclusion predicate for bags (less than or equal multiplicities)"
 operator BAG_COUNT             2  "multiplicity of an element in a bag"
+operator BAG_MEMBER            2  "bag membership predicate; is first parameter a member of second?"
 operator BAG_DUPLICATE_REMOVAL 1  "eliminate duplicates in a bag (also known as the delta operator,or the squash operator)"
 
 constant BAG_MAKE_OP \
@@ -91,6 +92,7 @@ typerule BAG_DIFFERENCE_SUBTRACT ::cvc5::theory::bags::BinaryOperatorTypeRule
 typerule BAG_DIFFERENCE_REMOVE   ::cvc5::theory::bags::BinaryOperatorTypeRule
 typerule BAG_SUBBAG              ::cvc5::theory::bags::SubBagTypeRule
 typerule BAG_COUNT               ::cvc5::theory::bags::CountTypeRule
+typerule BAG_MEMBER              ::cvc5::theory::bags::MemberTypeRule
 typerule BAG_DUPLICATE_REMOVAL   ::cvc5::theory::bags::DuplicateRemovalTypeRule
 typerule BAG_MAKE_OP             "SimpleTypeRule<RBuiltinOperator>"
 typerule BAG_MAKE                ::cvc5::theory::bags::BagMakeTypeRule
index 1a8f8f8491ea6cd36ed80cdcfcf38e109e8feb87..d8ed9fb959df82e514587a4f880ac7aa5627c599 100644 (file)
@@ -26,6 +26,7 @@ const char* toString(Rewrite r)
   switch (r)
   {
     case Rewrite::NONE: return "NONE";
+    case Rewrite::BAG_MAKE_COUNT_NEGATIVE: return "BAG_MAKE_COUNT_NEGATIVE";
     case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT";
     case Rewrite::CARD_BAG_MAKE: return "CARD_BAG_MAKE";
     case Rewrite::CHOOSE_BAG_MAKE: return "CHOOSE_BAG_MAKE";
@@ -51,7 +52,7 @@ const char* toString(Rewrite r)
     case Rewrite::MAP_CONST: return "MAP_CONST";
     case Rewrite::MAP_BAG_MAKE: return "MAP_BAG_MAKE";
     case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT";
-    case Rewrite::BAG_MAKE_COUNT_NEGATIVE: return "BAG_MAKE_COUNT_NEGATIVE";
+    case Rewrite::MEMBER: return "MEMBER";
     case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION";
     case Rewrite::REMOVE_MIN: return "REMOVE_MIN";
     case Rewrite::REMOVE_RETURN_LEFT: return "REMOVE_RETURN_LEFT";
index 0b71885992557bf7cd22e05dbd57f5f5a4c4f511..57f10621114aa41ec5172014ab08abca5af8798f 100644 (file)
@@ -31,6 +31,7 @@ namespace bags {
 enum class Rewrite : uint32_t
 {
   NONE,  // no rewrite happened
+  BAG_MAKE_COUNT_NEGATIVE,
   CARD_DISJOINT,
   CARD_BAG_MAKE,
   CHOOSE_BAG_MAKE,
@@ -55,7 +56,7 @@ enum class Rewrite : uint32_t
   MAP_CONST,
   MAP_BAG_MAKE,
   MAP_UNION_DISJOINT,
-  BAG_MAKE_COUNT_NEGATIVE,
+  MEMBER,
   REMOVE_FROM_UNION,
   REMOVE_MIN,
   REMOVE_RETURN_LEFT,
index fe81fadf5781b114607532e3d44f4109a6b07239..2d218f8218af1665f4aeec8c86c93b9c72e61c82 100644 (file)
@@ -120,6 +120,35 @@ TypeNode CountTypeRule::computeType(NodeManager* nodeManager,
   return nodeManager->integerType();
 }
 
+TypeNode MemberTypeRule::computeType(NodeManager* nodeManager,
+                                     TNode n,
+                                     bool check)
+{
+  Assert(n.getKind() == kind::BAG_MEMBER);
+  TypeNode bagType = n[1].getType(check);
+  if (check)
+  {
+    if (!bagType.isBag())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n, "checking for membership in a non-bag");
+    }
+    TypeNode elementType = n[0].getType(check);
+    // e.g. (bag.member 1 (bag 1.0 1)) is true whereas
+    // (bag.member 1.0 (bag 1 1)) throws a typing error
+    if (!elementType.isSubtypeOf(bagType.getBagElementType()))
+    {
+      std::stringstream ss;
+      ss << "member operating on bags of different types:\n"
+         << "child type:  " << elementType << "\n"
+         << "not subtype: " << bagType.getBagElementType() << "\n"
+         << "in term : " << n;
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  return nodeManager->booleanType();
+}
+
 TypeNode DuplicateRemovalTypeRule::computeType(NodeManager* nodeManager,
                                                TNode n,
                                                bool check)
index fa2f7831315d605356e6591a8f1adcf8938d4e71..da9ea75bfe3c221bf9e146cf9863f0583fcd0816 100644 (file)
@@ -57,6 +57,15 @@ struct CountTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct CountTypeRule */
 
+/**
+ * Type rule for binary operator bag.member to check the sort of the first
+ * argument matches the element sort of the given bag.
+ */
+struct MemberTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+};
+
 /**
  * Type rule for bag.duplicate_removal to check the argument is of a bag.
  */
index 3c0e79596d20716722c85e04d368eac02aebc3a5..6009c020da784f7b4fec89eac5ca861462c41469 100644 (file)
@@ -1611,6 +1611,7 @@ set(regress_1_tests
   regress1/bug681.smt2
   regress1/bug694-Unapply1.scala-0.smt2
   regress1/bug800.smt2
+  regress1/bags/bag_member.smt2
   regress1/bags/bags-of-bags-subtypes.smt2
   regress1/bags/card1.smt2
   regress1/bags/card2.smt2
diff --git a/test/regress/regress1/bags/bag_member.smt2 b/test/regress/regress1/bags/bag_member.smt2
new file mode 100644 (file)
index 0000000..a2275ca
--- /dev/null
@@ -0,0 +1,5 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun B () (Bag String))
+(assert (bag.member "x" B))
+(check-sat)