Add bag.partition evaluation (#8637)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 21 Apr 2022 04:46:43 +0000 (23:46 -0500)
committerGitHub <noreply@github.com>
Thu, 21 Apr 2022 04:46:43 +0000 (23:46 -0500)
16 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/bags_utils.cpp
src/theory/bags/bags_utils.h
src/theory/bags/kinds
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/bags/bag_partition1.smt2 [new file with mode: 0644]

index d3c28aa06a807cea4eb094140ef5b5e6b2ba216d..84967b5c9c5fc583d187987b652d8e8b8576eb62 100644 (file)
@@ -328,6 +328,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(BAG_MAP, internal::Kind::BAG_MAP),
         KIND_ENUM(BAG_FILTER, internal::Kind::BAG_FILTER),
         KIND_ENUM(BAG_FOLD, internal::Kind::BAG_FOLD),
+        KIND_ENUM(BAG_PARTITION, internal::Kind::BAG_PARTITION),
         KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT),
         KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT),
         /* Strings ---------------------------------------------------------- */
@@ -644,6 +645,7 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::BAG_MAP, BAG_MAP},
         {internal::Kind::BAG_FILTER, BAG_FILTER},
         {internal::Kind::BAG_FOLD, BAG_FOLD},
+        {internal::Kind::BAG_PARTITION, BAG_PARTITION},
         {internal::Kind::TABLE_PRODUCT, TABLE_PRODUCT},
         {internal::Kind::TABLE_PROJECT, TABLE_PROJECT},
         {internal::Kind::TABLE_PROJECT_OP, TABLE_PROJECT},
index bbc5cfdd8bf65af895a8ee18de504425fcaed07d..8ee2f378c0996f5a141fd012d9db48f18480de73 100644 (file)
@@ -3673,6 +3673,33 @@ enum Kind : int32_t
    * \endrst
    */
   BAG_FOLD,
+  /**
+   * Bag partition.
+   *
+   * \rst
+   * This operator partitions of a bag of elements into disjoint bags.
+   * (bag.partition :math:`r \; B`) partitions the elements of bag :math:`B`
+   * of type :math:`(Bag \; E)` based on the equivalence relations :math:`r` of
+   * type :math:`(\rightarrow \; E \; E \; Bool)`.
+   * It returns a bag of bags of type :math:`(Bag \; (Bag \; E))`.
+   *
+   * - Arity: ``2``
+   *
+   *   - ``1:`` Term of function Sort :math:`(\rightarrow \; E \; E \; Bool)`
+   *   - ``2:`` Term of bag Sort (Bag :math:`E`)
+   * \endrst
+   *
+   * - Create Term of this Kind with:
+   *
+   *   - Solver::mkTerm(Kind, const std::vector<Term>&) const
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
+   */
+  BAG_PARTITION,
   /**
    * Table cross product.
    *
index 93a518df0d3697d2bb742ab5f42ee9f1bb3b11ad..a4a16c214ad00ff700165585708a727196b3f351 100644 (file)
@@ -634,6 +634,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(cvc5::BAG_MAP, "bag.map");
     addOperator(cvc5::BAG_FILTER, "bag.filter");
     addOperator(cvc5::BAG_FOLD, "bag.fold");
+    addOperator(cvc5::BAG_PARTITION, "bag.partition");
     addOperator(cvc5::TABLE_PRODUCT, "table.product");
   }
   if (d_logic.isTheoryEnabled(internal::theory::THEORY_STRINGS))
index bc35f639f480ced995670b22a63b7d2d4892d99b..41fa39575d728b707ac82e20299c419683d11e09 100644 (file)
@@ -1183,6 +1183,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_MAP: return "bag.map";
   case kind::BAG_FILTER: return "bag.filter";
   case kind::BAG_FOLD: return "bag.fold";
+  case kind::BAG_PARTITION: return "bag.partition";
   case kind::TABLE_PRODUCT: return "table.product";
   case kind::TABLE_PROJECT: return "table.project";
 
index cdf2dde0250dd70eccf8e0aceb897bca16d665fd..6b7e49a31d3250a0061ebbdf4c24e44bccd80836 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "expr/emptybag.h"
 #include "theory/bags/bags_utils.h"
+#include "theory/rewriter.h"
 #include "util/rational.h"
 #include "util/statistics_registry.h"
 
@@ -41,8 +42,8 @@ BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
 {
 }
 
-BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
-    : d_statistics(statistics)
+BagsRewriter::BagsRewriter(Rewriter* r, HistogramStat<Rewrite>* statistics)
+    : d_rewriter(r), d_statistics(statistics)
 {
   d_nm = NodeManager::currentNM();
   d_zero = d_nm->mkConstInt(Rational(0));
@@ -92,6 +93,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
       case BAG_MAP: response = postRewriteMap(n); break;
       case BAG_FILTER: response = postRewriteFilter(n); break;
       case BAG_FOLD: response = postRewriteFold(n); break;
+      case BAG_PARTITION: response = postRewritePartition(n); break;
       case TABLE_PRODUCT: response = postRewriteProduct(n); break;
       default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
     }
@@ -648,6 +650,21 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
   return BagsRewriteResponse(n, Rewrite::NONE);
 }
 
+BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const
+{
+  Assert(n.getKind() == kind::BAG_PARTITION);
+  if (n[1].isConst())
+  {
+    Node ret = BagsUtils::evaluateBagPartition(d_rewriter, n);
+    if (ret != n)
+    {
+      return BagsRewriteResponse(ret, Rewrite::PARTITION_CONST);
+    }
+  }
+
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
 BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
 {
   Assert(n.getKind() == TABLE_PRODUCT);
index 00c8b6d0c479c9bacb5b94fbc91d4bc8036e0b85..3c08208a80edad56f6b7b62c3c9f999be9dc9f1f 100644 (file)
@@ -42,7 +42,7 @@ struct BagsRewriteResponse
 class BagsRewriter : public TheoryRewriter
 {
  public:
-  BagsRewriter(HistogramStat<Rewrite>* statistics = nullptr);
+  BagsRewriter(Rewriter* r, HistogramStat<Rewrite>* statistics = nullptr);
 
   /**
    * postRewrite nodes with kinds: BAG_MAKE, BAG_COUNT, BAG_UNION_MAX,
@@ -246,6 +246,7 @@ class BagsRewriter : public TheoryRewriter
    *  where f: T1 -> T2 -> T2
    */
   BagsRewriteResponse postRewriteFold(const TNode& n) const;
+  BagsRewriteResponse postRewritePartition(const TNode& n) const;
   /**
    *  rewrites for n include:
    *  - (bag.product A (as bag.empty T2)) = (as bag.empty T)
@@ -262,6 +263,11 @@ class BagsRewriter : public TheoryRewriter
   NodeManager* d_nm;
   Node d_zero;
   Node d_one;
+  /**
+   * Pointer to the rewriter. NOTE this is a cyclic dependency, and should
+   * be removed.
+   */
+  Rewriter* d_rewriter;
   /** Reference to the rewriter statistics. */
   HistogramStat<Rewrite>* d_statistics;
 }; /* class TheoryBagsRewriter */
index e719232486101bab5c6c79098b8e53cdc28719f8..fd5a98c25aebf6d5d02a31990ad04d8d47f644a2 100644 (file)
 #include "smt/logic_exception.h"
 #include "table_project_op.h"
 #include "theory/datatypes/tuple_utils.h"
+#include "theory/rewriter.h"
 #include "theory/sets/normal_form.h"
 #include "theory/type_enumerator.h"
+#include "theory/uf/equality_engine.h"
 #include "util/rational.h"
 
 using namespace cvc5::internal::kind;
@@ -785,6 +787,98 @@ Node BagsUtils::evaluateBagFold(TNode n)
   return ret;
 }
 
+Node BagsUtils::evaluateBagPartition(Rewriter* rewriter, TNode n)
+{
+  Assert(n.getKind() == BAG_PARTITION);
+  NodeManager* nm = NodeManager::currentNM();
+
+  // Examples
+  // --------
+  // minimum string
+  // - (bag.partition
+  //     ((lambda ((x Int) (y Int)) (= 0 (+ x y)))
+  //     (bag.union_disjoint
+  //       (bag 1 20) (bag (- 1) 50)
+  //       (bag 2 30) (bag (- 2) 60)
+  //       (bag 3 40) (bag (- 3) 70)
+  //       (bag 4 100)))
+  //   = (bag.union_disjoint
+  //       (bag (bag 4 100) 1)
+  //       (bag (bag.union_disjoint (bag 1 20) (bag (- 1) 50)) 1)
+  //       (bag (bag.union_disjoint (bag 2 30) (bag (- 2) 60)) 1)
+  //       (bag (bag.union_disjoint (bag 3 40) (bag (- 3) 70)) 1)))
+
+  Node r = n[0];  // equivalence relation
+  Node A = n[1];  // bag
+  TypeNode bagType = A.getType();
+  TypeNode partitionType = n.getType();
+  std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
+  Trace("bags-partition") << "elements: " << elements << std::endl;
+  // a simple map from elements to equivalent classes with this invariant:
+  // each key element must appear exactly once in one of the values.
+  std::map<Node, std::set<Node>> sets;
+  std::set<Node> emptyClass;
+  for (const auto& pair : elements)
+  {
+    // initially each singleton element is an equivalence class
+    sets[pair.first] = {pair.first};
+  }
+  for (std::map<Node, Rational>::iterator i = elements.begin();
+       i != elements.end();
+       ++i)
+  {
+    if (sets[i->first].empty())
+    {
+      // skip this element since its equivalent class has already been processed
+      continue;
+    }
+    std::map<Node, Rational>::iterator j = i;
+    ++j;
+    while (j != elements.end())
+    {
+      Node sameClass = nm->mkNode(APPLY_UF, r, i->first, j->first);
+      sameClass = rewriter->rewrite(sameClass);
+      if (!sameClass.isConst())
+      {
+        // we can not pursue further, so we return n itself
+        return n;
+      }
+      if (sameClass.getConst<bool>())
+      {
+        // add element j to the equivalent class
+        sets[i->first].insert(j->first);
+        // mark the equivalent class of j as processed
+        sets[j->first] = emptyClass;
+      }
+      ++j;
+    }
+  }
+
+  // construct the partition parts
+  std::map<Node, Rational> parts;
+  for (std::pair<Node, std::set<Node>> pair : sets)
+  {
+    const std::set<Node>& eqc = pair.second;
+    if (eqc.empty())
+    {
+      continue;
+    }
+    std::vector<Node> bags;
+    for (const Node& node : eqc)
+    {
+      Node bag = nm->mkBag(
+          bagType.getBagElementType(), node, nm->mkConstInt(elements[node]));
+      bags.push_back(bag);
+    }
+    Node part = computeDisjointUnion(bagType, bags);
+    // each part in the partitions has multiplicity one
+    parts[part] = Rational(1);
+  }
+  Node ret = constructConstantBagFromElements(partitionType, parts);
+  Trace("bags-partition") << "ret: " << ret << std::endl;
+  return ret;
+}
+
 Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2)
 {
   Assert(n.getKind() == TABLE_PRODUCT);
index 42e7b0caf066d7b35827a4010c49d3db746dbd7b..21de8e959e53dee35d039c9eaf9fbb2318e1f32e 100644 (file)
@@ -20,6 +20,8 @@
 #ifndef CVC5__THEORY__BAGS__UTILS_H
 #define CVC5__THEORY__BAGS__UTILS_H
 
+#include "theory/theory_rewriter.h"
+
 namespace cvc5::internal {
 namespace theory {
 namespace bags {
@@ -88,6 +90,12 @@ class BagsUtils
    */
   static Node evaluateBagFold(TNode n);
 
+  /**
+   * @param n has the form (bag.partition r A) where A is a constant bag
+   * @return a partition of A based on the equivalence relation r
+   */
+  static Node evaluateBagPartition(Rewriter *rewriter, TNode n);
+
   /**
    * @param n has the form (bag.filter p A) where A is a constant bag
    * @return A filtered with predicate p
index 49bca83fbbddab2cb161bc9cb4db0d2ab0b571a6..1e875e99880738ab921d0a8f4dad25ac10e570a4 100644 (file)
@@ -89,6 +89,10 @@ operator BAG_FILTER        2  "bag filter operator"
 #  B: a bag of type (Bag T1)
 operator BAG_FOLD          3  "bag fold operator"
 
+# bag.partition operator partitions a bag into a bag of bags based on an equivalence relation such that
+# each element occurs exactly in one these bags.
+operator BAG_PARTITION     2 "bag partition operator"
+
 typerule BAG_UNION_MAX           ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
 typerule BAG_UNION_DISJOINT      ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
 typerule BAG_INTER_MIN           ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
@@ -109,6 +113,7 @@ typerule BAG_TO_SET              ::cvc5::internal::theory::bags::ToSetTypeRule
 typerule BAG_MAP                 ::cvc5::internal::theory::bags::BagMapTypeRule
 typerule BAG_FILTER              ::cvc5::internal::theory::bags::BagFilterTypeRule
 typerule BAG_FOLD                ::cvc5::internal::theory::bags::BagFoldTypeRule
+typerule BAG_PARTITION           ::cvc5::internal::theory::bags::BagPartitionTypeRule
 
 construle BAG_UNION_DISJOINT     ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
 construle BAG_MAKE               ::cvc5::internal::theory::bags::BagMakeTypeRule
index 287e879a7d7add8b92927fce2c54683dfcc6dd11..0c634351af1034484c2725ca27b8f528c73891ed 100644 (file)
@@ -56,6 +56,7 @@ const char* toString(Rewrite r)
     case Rewrite::MAP_BAG_MAKE: return "MAP_BAG_MAKE";
     case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT";
     case Rewrite::MEMBER: return "MEMBER";
+    case Rewrite::PARTITION_CONST: return "PARTITION_CONST";
     case Rewrite::PRODUCT_EMPTY: return "PRODUCT_EMPTY";
     case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION";
     case Rewrite::REMOVE_MIN: return "REMOVE_MIN";
index 467de36db88d31d5986ac35250504f146ed8f157..461ea8703d3df36296d770c60772b4d835362821 100644 (file)
@@ -60,6 +60,7 @@ enum class Rewrite : uint32_t
   MAP_BAG_MAKE,
   MAP_UNION_DISJOINT,
   MEMBER,
+  PARTITION_CONST,
   PRODUCT_EMPTY,
   REMOVE_FROM_UNION,
   REMOVE_MIN,
index 92ea5eccaa6cf0c3d8557247a92356d5b8d2a480..adcf3d468c1b6849cc1c30fa6db10c19cb491e73 100644 (file)
@@ -38,7 +38,7 @@ TheoryBags::TheoryBags(Env& env, OutputChannel& out, Valuation valuation)
       d_ig(&d_state, &d_im),
       d_notify(*this, d_im),
       d_statistics(),
-      d_rewriter(&d_statistics.d_rewrites),
+      d_rewriter(env.getRewriter(), &d_statistics.d_rewrites),
       d_termReg(env, d_state, d_im),
       d_solver(env, d_state, d_im, d_termReg),
       d_cardSolver(env, d_state, d_im),
@@ -80,6 +80,7 @@ void TheoryBags::finishInit()
   d_equalityEngine->addFunctionKind(BAG_CARD);
   d_equalityEngine->addFunctionKind(BAG_FROM_SET);
   d_equalityEngine->addFunctionKind(BAG_TO_SET);
+  d_equalityEngine->addFunctionKind(BAG_PARTITION);
   d_equalityEngine->addFunctionKind(TABLE_PRODUCT);
   d_equalityEngine->addFunctionKind(TABLE_PROJECT);
 }
@@ -455,6 +456,7 @@ void TheoryBags::preRegisterTerm(TNode n)
     case BAG_FROM_SET:
     case BAG_TO_SET:
     case BAG_IS_SINGLETON:
+    case BAG_PARTITION:
     case TABLE_PROJECT:
     {
       std::stringstream ss;
index ef2a5a35002bb59f1fd0918a45e4aee2bfc47cea..e786a6afc68818e0567e417639bfcd9fa1f575d0 100644 (file)
@@ -454,6 +454,51 @@ TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager,
   return retType;
 }
 
+TypeNode BagPartitionTypeRule::computeType(NodeManager* nodeManager,
+                                           TNode n,
+                                           bool check)
+{
+  Assert(n.getKind() == kind::BAG_PARTITION);
+  TypeNode functionType = n[0].getType(check);
+  TypeNode bagType = n[1].getType(check);
+  NodeManager* nm = NodeManager::currentNM();
+  if (check)
+  {
+    if (!bagType.isBag())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n,
+          "bag.partition operator expects a bag in the second argument, "
+          "a non-bag is found");
+    }
+
+    TypeNode elementType = bagType.getBagElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " " << elementType << " Bool) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    TypeNode rangeType = functionType.getRangeType();
+    if (!(argTypes.size() == 2 && elementType.isSubtypeOf(argTypes[0])
+          && elementType.isSubtypeOf(argTypes[1])
+          && rangeType == nm->booleanType()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " " << elementType << " Bool) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  TypeNode retType = nm->mkBagType(bagType);
+  return retType;
+}
+
 TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager,
                                            TNode n,
                                            bool check)
index 54329b405de0ff2e1069aaa95462a1601c91be95..04e5bfd04f8b295e8f8917bcb0fc29e3f9339fce 100644 (file)
@@ -152,13 +152,22 @@ struct BagFilterTypeRule
 
 /**
  * Type rule for (bag.fold f t A) to make sure f is a binary operation of type
- * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1)
+ * (-> T1 T2 T2), t of type T2, and A is a bag of type (Bag T1)
  */
 struct BagFoldTypeRule
 {
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct BagFoldTypeRule */
 
+/**
+ * Type rule for (bag.partition r A) to make sure r is a binary operation of type
+ * (-> T1 T1 Bool), and A is a bag of type (Bag T1)
+ */
+struct BagPartitionTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFoldTypeRule */
+
 /**
  * Type rule for (table.product A B) to make sure A,B are bags of tuples,
  * and get the type of the cross product
index 59891768a813fc74295ba0079d897b953dde76c6..ca74fc74f9df08627bcbb7cb3d70be253b41e357 100644 (file)
@@ -1753,6 +1753,7 @@ set(regress_1_tests
   regress1/bug694-Unapply1.scala-0.smt2
   regress1/bug800.smt2
   regress1/bags/bag_member.smt2
+  regress1/bags/bag_partition1.smt2
   regress1/bags/bags-of-bags-subtypes.smt2
   regress1/bags/card1.smt2
   regress1/bags/card2.smt2
diff --git a/test/regress/cli/regress1/bags/bag_partition1.smt2 b/test/regress/cli/regress1/bags/bag_partition1.smt2
new file mode 100644 (file)
index 0000000..84c7323
--- /dev/null
@@ -0,0 +1,34 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+
+; equivalence relation : inverse
+(define-fun r ((x Int) (y Int)) Bool (= 0 (+ x y)))
+
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag (Bag Int)))
+(declare-fun C () (Bag (Bag Int)))
+
+(assert
+ (= A
+    (bag.union_disjoint
+     (bag 1 20) (bag (- 1) 50)
+     (bag 2 30) (bag (- 2) 60)
+     (bag 3 40) (bag (- 3) 70)
+     (bag 4 100))))
+
+;(define-fun B () (Bag (Bag Int))
+;  (bag.union_disjoint (bag (bag 4 100) 1)
+;                      (bag (bag.union_disjoint (bag 1 20) (bag (- 1) 50)) 1)
+;                      (bag (bag.union_disjoint (bag 2 30) (bag (- 2) 60)) 1)
+;                      (bag (bag.union_disjoint (bag 3 40) (bag (- 3) 70)) 1)))
+
+(assert (= B (bag.partition r A)))
+; (define-fun C () (Bag (Bag Int)) (as bag.empty (Bag (Bag Int))))
+(assert (= C (bag.partition r (as bag.empty (Bag Int)))))
+
+(check-sat)
+