Add table.group operator (#8731)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 24 May 2022 14:51:49 +0000 (09:51 -0500)
committerGitHub <noreply@github.com>
Tue, 24 May 2022 14:51:49 +0000 (14:51 +0000)
22 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/parser/smt2/Smt2.g
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_reduction.cpp
src/theory/bags/bag_reduction.h
src/theory/bags/bags_utils.cpp
src/theory/bags/bags_utils.h
src/theory/bags/kinds
src/theory/bags/solver_state.cpp
src/theory/bags/table_project_op.cpp
src/theory/bags/table_project_op.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
src/theory/datatypes/tuple_utils.cpp
src/theory/datatypes/tuple_utils.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/bags/table_group1.smt2 [new file with mode: 0644]

index 980e6468b4131b4c3fb9b7657e2c10781e2079e0..b6c8fc95d405a9081bef5660623193c688ea74c4 100644 (file)
@@ -334,6 +334,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT),
         KIND_ENUM(TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE),
         KIND_ENUM(TABLE_JOIN, internal::Kind::TABLE_JOIN),
+        KIND_ENUM(TABLE_GROUP, internal::Kind::TABLE_GROUP),
         /* Strings ---------------------------------------------------------- */
         KIND_ENUM(STRING_CONCAT, internal::Kind::STRING_CONCAT),
         KIND_ENUM(STRING_IN_REGEXP, internal::Kind::STRING_IN_REGEXP),
@@ -657,6 +658,8 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::TABLE_AGGREGATE, TABLE_AGGREGATE},
         {internal::Kind::TABLE_JOIN_OP, TABLE_JOIN},
         {internal::Kind::TABLE_JOIN, TABLE_JOIN},
+        {internal::Kind::TABLE_GROUP_OP, TABLE_GROUP},
+        {internal::Kind::TABLE_GROUP, TABLE_GROUP},
         /* Strings --------------------------------------------------------- */
         {internal::Kind::STRING_CONCAT, STRING_CONCAT},
         {internal::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
@@ -6116,6 +6119,9 @@ Op Solver::mkOp(Kind kind, const std::vector<uint32_t>& args) const
     case TABLE_JOIN:
       res = mkOpHelper(kind, internal::TableJoinOp(args));
       break;
+    case TABLE_GROUP:
+      res = mkOpHelper(kind, internal::TableGroupOp(args));
+      break;
     default:
       if (nargs == 0)
       {
index 2e1fc435fd78b09a36929a7cd670aa4b2a4614cf..40797e027718fa78fe37710c71287117775685ff 100644 (file)
@@ -3743,16 +3743,22 @@ enum Kind : int32_t
    *
    * - Create Op of this kind with:
    *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
    */
   TABLE_PROJECT,
   /**
+   * \rst
+   *
    * Table aggregate operator has the form
-   * :math:`((\_ \; table.aggr \; n_1 ... n_k) f i A)`
+   * :math:`((\_ \; table.aggr \; n_1 ... n_k) \; f \; i \; A)`
    * where :math:`n_1, ..., n_k` are natural numbers,
    * :math:`f` is a function of type
    * :math:`(\rightarrow (Tuple \;  T_1 \; ... \; T_j)\; T \; T)`,
    * :math:`i` has the type :math:`T`,
-   * and :math`A` has type :math:`Table T_1 ... T_j`.
+   * and :math:`A` has type :math:`(Table \;  T_1 \; ... \; T_j)`.
    * The returned type is :math:`(Bag \; T)`.
    *
    * This operator aggregates elements in A that have the same tuple projection
@@ -3760,43 +3766,89 @@ enum Kind : int32_t
    * and initial value :math:`i`.
    *
    * - Arity: ``3``
+   *
    *   - ``1:`` Term of sort :math:`(\rightarrow (Tuple \;  T_1 \; ... \; T_j)\; T \; T)`
    *   - ``2:`` Term of Sort :math:`T`
    *   - ``3:`` Term of table sort :math:`Table T_1 ... T_j`
    *
    * - Indices: ``n``
    *   - ``1..n:`` Indices of the projection
-   *
+   * \endrst
    * - Create Term of this Kind with:
    *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
    *
    * - Create Op of this kind with:
    *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
    */
   TABLE_AGGREGATE,
   /**
-   * Table join operator has the form
+   * \rst
+   *  Table join operator has the form
    *  :math:`((\_ \; table.join \; m_1 \; n_1 \; \dots \; m_k \; n_k) \; A \; B)`
-   *  where *  :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers,
-   *  and A, B are tables.
+   *  where :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers,
+   *  and :math:`A, B` are tables.
    *  This operator filters the product of two bags based on the equality of
-   *  projected tuples using indices :math:`m_1, \dots, m_k` in table A,
-   *  and indices :math:`n_1, \dots, n_k` in table B.
+   *  projected tuples using indices :math:`m_1, \dots, m_k` in table :math:`A`,
+   *  and indices :math:`n_1, \dots, n_k` in table :math:`B`.
    *
    * - Arity: ``2``
+   *
    *   - ``1:`` Term of table Sort
+   *
    *   - ``2:`` Term of table Sort
    *
    * - Indices: ``n``
    *   - ``1..n:``  Indices of the projection
    *
+   * \endrst
    * - Create Term of this Kind with:
    *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
    *
    * - Create Op of this kind with:
    *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
    */
   TABLE_JOIN,
+  /**
+   * Table group
+   *
+   * \rst
+   * :math:`((\_ \; table.group \; n_1 \; \dots \; n_k) \; A)` partitions tuples
+   * of table :math:`A` such that tuples that have the same projection
+   * with indices :math:`n_1 \; \dots \; n_k` are in the same part.
+   * It returns a bag of tables of type :math:`(Bag \; T)` where
+   * :math:`T` is the type of :math:`A`.
+   *
+   * - Arity: ``1``
+   *
+   *   - ``1:`` Term of table sort
+   *
+   * - Indices: ``n``
+   *
+   *   - ``1..n:``  Indices of the projection
+   *
+   * \endrst
+   *
+   * - Create Term of this Kind with:
+   *
+   *   - Solver::mkTerm(Kind, const std::vector<Term>&) const
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
+   */
+  TABLE_GROUP,
   /* Strings --------------------------------------------------------------- */
 
   /**
index 993d9b4d8bfb0951d9409447292e25666a8312be..58276262bf7c0b23f12da9084ab0dd4ed0f4b341 100644 (file)
@@ -94,7 +94,7 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::BAGS_MAP_PREIMAGE_SIZE: return "BAGS_MAP_PREIMAGE_SIZE";
     case SkolemFunId::BAGS_MAP_PREIMAGE_INDEX: return "BAGS_MAP_PREIMAGE_INDEX";
     case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM";
-    case SkolemFunId::BAG_DEQ_DIFF: return "BAG_DEQ_DIFF";
+    case SkolemFunId::BAGS_DEQ_DIFF: return "BAGS_DEQ_DIFF";
     case SkolemFunId::SETS_CHOOSE: return "SETS_CHOOSE";
     case SkolemFunId::SETS_DEQ_DIFF: return "SETS_DEQ_DIFF";
     case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED";
index 2480d1d0c962671ad9bde07f1bc2161a61097767..9d4f225f9914f812c483985743287eb9680a54f0 100644 (file)
@@ -168,7 +168,7 @@ enum class SkolemFunId
    */
   BAGS_MAP_SUM,
   /** bag diff to witness (not (= A B)) */
-  BAG_DEQ_DIFF,
+  BAGS_DEQ_DIFF,
   /** An interpreted function for bag.choose operator:
    * (choose A) is expanded as
    * (witness ((x elementType))
index 239d36468fe35d92178890613716f55e23b15daa..ee0738b114233832ad9da3a1a7925f5668f7ab9a 100644 (file)
@@ -1425,6 +1425,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2]
     cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_JOIN, indices);
     expr = SOLVER->mkTerm(op, {expr});
   }
+  | LPAREN_TOK TABLE_GROUP_TOK term[expr,expr2] RPAREN_TOK
+  {
+    std::vector<uint32_t> indices;
+    cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_GROUP, indices);
+    expr = SOLVER->mkTerm(op, {expr});
+  }
   | /* an atomic term (a term with no subterms) */
     termAtomic[atomTerm] { expr = atomTerm; }
   ;
@@ -1587,6 +1593,13 @@ identifier[cvc5::ParseOp& p]
         p.d_kind = cvc5::TABLE_JOIN;
         p.d_op = SOLVER->mkOp(cvc5::TABLE_JOIN, numerals);
       }
+     | TABLE_GROUP_TOK nonemptyNumeralList[numerals]
+      {
+        // we adopt a special syntax (_ table.group i_1 ... i_n) where
+        // i_1, ..., j_n are numerals
+        p.d_kind = cvc5::TABLE_GROUP;
+        p.d_op = SOLVER->mkOp(cvc5::TABLE_GROUP, numerals);
+      }
     | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals]
       {
         cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName);
@@ -2198,6 +2211,7 @@ TUPLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATA
 TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.project';
 TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.aggr';
 TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join';
+TABLE_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.group';
 FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card';
 
 HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->';
index cf28d0ac8cafbb5766e767a82b771f9d06d2ec77..66f2db214a2db5538e41b248130b6f4358bd6c77 100644 (file)
@@ -636,6 +636,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(cvc5::BAG_FOLD, "bag.fold");
     addOperator(cvc5::BAG_PARTITION, "bag.partition");
     addOperator(cvc5::TABLE_PRODUCT, "table.product");
+    addOperator(cvc5::BAG_PARTITION, "table.group");
   }
   if (d_logic.isTheoryEnabled(internal::theory::THEORY_STRINGS))
   {
@@ -1126,7 +1127,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector<cvc5::Term>& args)
     return ret;
   }
   else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT
-           || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN)
+           || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN
+           || p.d_kind == cvc5::TABLE_GROUP)
   {
     cvc5::Term ret = d_solver->mkTerm(p.d_op, args);
     Trace("parser") << "applyParseOp: return projection " << ret << std::endl;
index ec29595e7c7bdc4cb6246d4d19f7b2015d8c9598..2790463003b01f31cf4d349a6f0b7bf99f7617bb 100644 (file)
@@ -792,6 +792,21 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     return;
   }
+  case kind::TABLE_GROUP:
+  {
+    TableGroupOp op = n.getOperator().getConst<TableGroupOp>();
+    if (op.getIndices().empty())
+    {
+      // e.g. (table.group A)
+      out << "table.group " << n[0] << ")";
+    }
+    else
+    {
+      // e.g. ((_ table.group 0 1 2 3) A)
+      out << "(_ table.group" << op << ") " << n[0] << ")";
+    }
+    return;
+  }
   case kind::CONSTRUCTOR_TYPE:
   {
     out << n[n.getNumChildren()-1];
@@ -1155,6 +1170,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::TABLE_PROJECT: return "table.project";
   case kind::TABLE_AGGREGATE: return "table.aggr";
   case kind::TABLE_JOIN: return "table.join";
+  case kind::TABLE_GROUP: return "table.group";
 
     // fp theory
   case kind::FLOATINGPOINT_FP: return "fp";
index 89a4223010b1e5d0956e761fcaf0a12fccb1d920..e7ccba325b70fb1991e1c017146803b7a6babb4c 100644 (file)
@@ -214,28 +214,16 @@ Node BagReduction::reduceAggregateOperator(Node node)
   const std::vector<uint32_t>& indices =
       node.getOperator().getConst<TableAggregateOp>().getIndices();
 
-  Node t1 = bvm->mkBoundVar<FirstIndexVarAttribute>(node, "t1", elementType);
-  Node t2 = bvm->mkBoundVar<SecondIndexVarAttribute>(node, "t2", elementType);
-  Node list = nm->mkNode(BOUND_VAR_LIST, t1, t2);
-  Node body = nm->mkConst(true);
-  for (uint32_t i : indices)
-  {
-    Node select1 = datatypes::TupleUtils::nthElementOfTuple(t1, i);
-    Node select2 = datatypes::TupleUtils::nthElementOfTuple(t2, i);
-    Node equal = select1.eqNode(select2);
-    body = body.andNode(equal);
-  }
-
-  Node lambda = nm->mkNode(LAMBDA, list, body);
-  Node partition = nm->mkNode(BAG_PARTITION, lambda, A);
+  Node groupOp = nm->mkConst(TableGroupOp(indices));
+  Node group = nm->mkNode(TABLE_GROUP, {groupOp, A});
 
   Node bag = bvm->mkBoundVar<FirstIndexVarAttribute>(
-      partition, "bag", nm->mkBagType(elementType));
+      group, "bag", nm->mkBagType(elementType));
   Node foldList = nm->mkNode(BOUND_VAR_LIST, bag);
   Node foldBody = nm->mkNode(BAG_FOLD, function, initialValue, bag);
 
   Node fold = nm->mkNode(LAMBDA, foldList, foldBody);
-  Node map = nm->mkNode(BAG_MAP, fold, partition);
+  Node map = nm->mkNode(BAG_MAP, fold, group);
   return map;
 }
 
index c3f49b0a443770aeb3321a2b3b704310c6bf1426..cf391a120a3b640e4dce5c79112258dd63016937 100644 (file)
@@ -98,16 +98,11 @@ class BagReduction
   static Node reduceCardOperator(Node node, std::vector<Node>& asserts);
   /**
    * @param node of the form ((_ table.aggr n1 ... nk) f initial A))
-   * @return reduction term that uses map, fold, and partition using
-   * tuple projection as the equivalence relation as follows:
+   * @return reduction term that uses map, fold, and group operators
+   * as follows:
    * (bag.map
    *   (lambda ((B Table)) (bag.fold f initial B))
-   *   (bag.partition
-   *     (lambda ((t1 Tuple) (t2 Tuple)) ; equivalence relation
-   *             (=
-   *               ((_ tuple.project n1 ... nk) t1)
-   *               ((_ tuple.project n1 ... nk) t2)))
-   *     A))
+   *   ((_ table.group n1 ... nk) A))
    */
   static Node reduceAggregateOperator(Node node);
 };
index 0987bccfc8c32cd89f23f8a90f413583f2a88927..4935a24d456a9b1c17f20e699d6264f19fa0f5f9 100644 (file)
@@ -146,6 +146,7 @@ Node BagsUtils::evaluate(Rewriter* rewriter, TNode n)
     case BAG_FOLD: return evaluateBagFold(n);
     case TABLE_PRODUCT: return evaluateProduct(n);
     case TABLE_JOIN: return evaluateJoin(rewriter, n);
+    case TABLE_GROUP: return evaluateGroup(rewriter, n);
     case TABLE_PROJECT: return evaluateTableProject(n);
     default: break;
   }
@@ -974,6 +975,84 @@ Node BagsUtils::evaluateJoin(Rewriter* rewriter, TNode n)
   return ret;
 }
 
+Node BagsUtils::evaluateGroup(Rewriter* rewriter, TNode n)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  Node A = n[0];
+  TypeNode bagType = A.getType();
+  TypeNode partitionType = n.getType();
+
+  if (A.getKind() == BAG_EMPTY)
+  {
+    // return a nonempty partition
+    return nm->mkNode(BAG_MAKE, A, nm->mkConstInt(Rational(1)));
+  }
+
+  std::vector<uint32_t> indices =
+      n.getOperator().getConst<TableGroupOp>().getIndices();
+
+  std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
+  Trace("bags-group") << "elements: " << elements << std::endl;
+  // a simple map from elements to equivalent classes with this invariant:
+  // each key element must appear exactly once in one of the values.
+  std::map<Node, std::set<Node>> sets;
+  std::set<Node> emptyClass;
+  for (const auto& pair : elements)
+  {
+    // initially each singleton element is an equivalence class
+    sets[pair.first] = {pair.first};
+  }
+  for (std::map<Node, Rational>::iterator i = elements.begin();
+       i != elements.end();
+       ++i)
+  {
+    if (sets[i->first].empty())
+    {
+      // skip this element since its equivalent class has already been processed
+      continue;
+    }
+    std::map<Node, Rational>::iterator j = i;
+    ++j;
+    while (j != elements.end())
+    {
+      if (TupleUtils::sameProjection(indices, i->first, j->first))
+      {
+        // add element j to the equivalent class
+        sets[i->first].insert(j->first);
+        // mark the equivalent class of j as processed
+        sets[j->first] = emptyClass;
+      }
+      ++j;
+    }
+  }
+
+  // construct the partition parts
+  std::map<Node, Rational> parts;
+  for (std::pair<Node, std::set<Node>> pair : sets)
+  {
+    const std::set<Node>& eqc = pair.second;
+    if (eqc.empty())
+    {
+      continue;
+    }
+    std::vector<Node> bags;
+    for (const Node& node : eqc)
+    {
+      Node bag = nm->mkNode(BAG_MAKE, node, nm->mkConstInt(elements[node]));
+      bags.push_back(bag);
+    }
+    Node part = computeDisjointUnion(bagType, bags);
+    // each part in the partitions has multiplicity one
+    parts[part] = Rational(1);
+  }
+  Node ret = constructConstantBagFromElements(partitionType, parts);
+  Trace("bags-partition") << "ret: " << ret << std::endl;
+  return ret;
+}
+
 Node BagsUtils::evaluateTableProject(TNode n)
 {
   Assert(n.getKind() == TABLE_PROJECT);
index 23f21371b75d14a8656405de9ae5dfb84b725baf..4da592d5ae69d6928cd0611c4f6121c49dcf4510 100644 (file)
@@ -132,6 +132,14 @@ class BagsUtils
    */
   static Node evaluateJoin(Rewriter* rewriter, TNode n);
 
+  /**
+   * @param n of the form ((_ table.group (n_1 ... n_k) ) A) where A is a
+   * constant table
+   * @return a partition of A such that each part contains tuples with the same
+   * projection with indices n_1 ... n_k
+   */
+  static Node evaluateGroup(Rewriter* rewriter, TNode n);
+
   /**
    * @param n of the form ((_ table.project i_1 ... i_n) A) where A is a
    * constant
index d3a98a311cb4453f087225a1172467a410c56532..2522585bdea937b9266ff5dfd4137a6bf9d2c675 100644 (file)
@@ -144,6 +144,16 @@ constant TABLE_JOIN_OP \
 
 parameterized TABLE_JOIN TABLE_JOIN_OP 2 "table join"
 
+# table.group operator
+constant TABLE_GROUP_OP \
+  class \
+  TableGroupOp \
+  ::cvc5::internal::TableGroupOpHashFunction \
+  "theory/bags/table_project_op.h" \
+  "operator for TABLE_GROUP; payload is an instance of the cvc5::internal::TableGroupOp class"
+
+parameterized TABLE_GROUP TABLE_GROUP_OP 1 "table group"
+
 typerule TABLE_PRODUCT              ::cvc5::internal::theory::bags::TableProductTypeRule
 typerule TABLE_PROJECT_OP           "SimpleTypeRule<RBuiltinOperator>"
 typerule TABLE_PROJECT              ::cvc5::internal::theory::bags::TableProjectTypeRule
@@ -151,5 +161,7 @@ typerule TABLE_AGGREGATE_OP         "SimpleTypeRule<RBuiltinOperator>"
 typerule TABLE_AGGREGATE            ::cvc5::internal::theory::bags::TableAggregateTypeRule
 typerule TABLE_JOIN_OP              "SimpleTypeRule<RBuiltinOperator>"
 typerule TABLE_JOIN                 ::cvc5::internal::theory::bags::TableJoinTypeRule
+typerule TABLE_GROUP_OP             "SimpleTypeRule<RBuiltinOperator>"
+typerule TABLE_GROUP                ::cvc5::internal::theory::bags::TableGroupTypeRule
 
 endtheory
index f57fc1206d34bd1eb09bb835d42b776b0f3a20c0..5fe8bae13e08a1e8fbc65ff913c7628517492f5a 100644 (file)
@@ -119,7 +119,7 @@ void SolverState::collectDisequalBagTerms()
         TypeNode elementType = A.getType().getBagElementType();
         SkolemManager* sm = d_nm->getSkolemManager();
         Node skolem = sm->mkSkolemFunction(
-            SkolemFunId::BAG_DEQ_DIFF, elementType, {A, B});
+            SkolemFunId::BAGS_DEQ_DIFF, elementType, {A, B});
         d_deq[equal] = skolem;
       }
     }
index 72700be9d36dd8278e1b3713d04cb738ad852a32..9c19cdba7abb2e9dc77458f616e7bb9aec3b15a8 100644 (file)
@@ -32,4 +32,9 @@ TableJoinOp::TableJoinOp(std::vector<uint32_t> indices)
 {
 }
 
+TableGroupOp::TableGroupOp(std::vector<uint32_t> indices)
+    : ProjectOp(std::move(indices))
+{
+}
+
 }  // namespace cvc5::internal
index 10c45f915b74e942139f56df3ce803f67832e7d8..03f2f556120791c8cb76f8fd2469bfdc92d51ca2 100644 (file)
@@ -54,7 +54,6 @@ struct TableAggregateOpHashFunction : public ProjectOpHashFunction
 {
 }; /* struct TableAggregateOpHashFunction */
 
-
 class TableJoinOp : public ProjectOp
 {
  public:
@@ -69,6 +68,19 @@ struct TableJoinOpHashFunction : public ProjectOpHashFunction
 {
 }; /* struct TableJoinOpHashFunction */
 
+class TableGroupOp : public ProjectOp
+{
+ public:
+  explicit TableGroupOp(std::vector<uint32_t> indices);
+  TableGroupOp(const TableGroupOp& op) = default;
+}; /* class TableGroupOp */
+
+/**
+ * Hash function for the TableGroupOpHashFunction objects.
+ */
+struct TableGroupOpHashFunction : public ProjectOpHashFunction
+{
+}; /* struct TableGroupOpHashFunction */
 
 }  // namespace cvc5::internal
 
index c1d34eac8ad2e174ebd5956163e0ffdeb28b36c1..1581a091bddabbb5ee699facc0ced0dd56447894 100644 (file)
@@ -84,6 +84,7 @@ void TheoryBags::finishInit()
   d_equalityEngine->addFunctionKind(TABLE_PROJECT);
   d_equalityEngine->addFunctionKind(TABLE_AGGREGATE);
   d_equalityEngine->addFunctionKind(TABLE_JOIN);
+  d_equalityEngine->addFunctionKind(TABLE_GROUP);
 }
 
 TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
index ee3d9b5a859cdc3d242f25950feedde0b5202c32..0bcf956122f14d130b1bf0b7a966e9bf40b90871 100644 (file)
@@ -707,6 +707,39 @@ TypeNode TableJoinTypeRule::computeType(NodeManager* nm, TNode n, bool check)
   return nm->mkBagType(retTupleType);
 }
 
+TypeNode TableGroupTypeRule::computeType(NodeManager* nm, TNode n, bool check)
+{
+  Assert(n.getKind() == kind::TABLE_GROUP && n.hasOperator()
+         && n.getOperator().getKind() == kind::TABLE_GROUP_OP);
+  TableGroupOp op = n.getOperator().getConst<TableGroupOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+
+  TypeNode bagType = n[0].getType(check);
+
+  if (check)
+  {
+    if (!bagType.isBag())
+    {
+      std::stringstream ss;
+      ss << "TABLE_GROUP operator expects a table. Found '" << n[0]
+         << "' of type '" << bagType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TypeNode tupleType = bagType.getBagElementType();
+    if (!tupleType.isTuple())
+    {
+      std::stringstream ss;
+      ss << "TABLE_GROUP operator expects a table. Found '" << n[0]
+         << "' of type '" << bagType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TupleUtils::checkTypeIndices(n, tupleType, indices);
+  }
+  return nm->mkBagType(bagType);
+}
+
 Cardinality BagsProperties::computeCardinality(TypeNode type)
 {
   return Cardinality::INTEGERS;
index 445d627db4a4d2b4ada5db1e3b482da746263bd4..86fa4282297902e3c3dac23e8ad3191409e4443e 100644 (file)
@@ -216,6 +216,17 @@ struct TableJoinTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct TableJoinTypeRule */
 
+/**
+ * Table group operator is indexed by a list of indices (n_1, ..., n_k). It
+ * ensures that the argument is a table whose arity is greater than each n_i for
+ * i = 1, ..., k. If the passed table is of type T, then the returned type is
+ * (Bag T), i.e., bag of tables.
+ */
+struct TableGroupTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct TableGroupTypeRule */
+
 struct BagsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
index 838e840c3799306e90fefda6540645272696dd8c..74024f5083bf0c308d6e3231b6431f97f4d4b227 100644 (file)
@@ -145,6 +145,23 @@ std::vector<Node> TupleUtils::getTupleElements(Node tuple1, Node tuple2)
   return elements;
 }
 
+bool TupleUtils::sameProjection(const std::vector<uint32_t>& indices,
+                                Node tuple1,
+                                Node tuple2)
+{
+  Assert(tuple1.isConst() && tuple2.isConst())
+      << "Both " << tuple1 << " and " << tuple2 << " are not constants"
+      << std::endl;
+  for (uint32_t index : indices)
+  {
+    if (tuple1[index] != tuple2[index])
+    {
+      return false;
+    }
+  }
+  return true;
+}
+
 Node TupleUtils::constructTupleFromElements(TypeNode tupleType,
                                             const std::vector<Node>& elements,
                                             size_t start,
index 9afbd59fe724d74f54216feed8500776b114786d..a7f76ccd241f194aecf9bb38adaee7068d4be705 100644 (file)
@@ -81,6 +81,18 @@ class TupleUtils
    */
   static std::vector<Node> getTupleElements(Node tuple1, Node tuple2);
 
+  /**
+   * @param indices a list of indices for projected elements n_1, ..., n_k
+   * @param tuple1 a constant tuple node
+   * @param tuple2 a constant tuple node
+   * @return a boolean representing the equality of
+   * ((_ tuple.projection n_1 ... n_k) tuple1) and
+   * ((_ tuple.projection n_1 ... n_k) tuple2).
+   */
+  static bool sameProjection(const std::vector<uint32_t>& indices,
+                             Node tuple1,
+                             Node tuple2);
+
   /**
    * construct a tuple from a list of elements
    * @param tupleType the type of the returned tuple
index a9d2f4f6e2e932a4cbfd57fd896fecff843ae3b3..82e07c46ed53b4c72032a59a696f2ba5809846e9 100644 (file)
@@ -1828,6 +1828,7 @@ set(regress_1_tests
   regress1/bags/subbag1.smt2
   regress1/bags/subbag2.smt2
   regress1/bags/table_aggregate1.smt2
+  regress1/bags/table_group1.smt2
   regress1/bags/table_join1.smt2
   regress1/bags/table_join2.smt2
   regress1/bags/table_join3.smt2
diff --git a/test/regress/cli/regress1/bags/table_group1.smt2 b/test/regress/cli/regress1/bags/table_group1.smt2
new file mode 100644 (file)
index 0000000..5bb0b44
--- /dev/null
@@ -0,0 +1,116 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(define-fun truthTable () (Table String String String)
+  (bag.union_disjoint
+   (bag (tuple "A" "X" "0") 2)
+   (bag (tuple "A" "X" "1") 2)
+   (bag (tuple "A" "Y" "0") 2)
+   (bag (tuple "A" "Y" "1") 2)
+   (bag (tuple "B" "X" "0") 2)
+   (bag (tuple "B" "X" "1") 2)
+   (bag (tuple "B" "Y" "0") 2)
+   (bag (tuple "B" "Y" "1") 2)))
+
+; parition by first column
+(assert
+ (= ((_ table.group 0) truthTable)
+    (bag.union_disjoint
+     (bag
+      (bag.union_disjoint (bag (tuple "A" "X" "0") 2)
+                          (bag (tuple "A" "X" "1") 2)
+                          (bag (tuple "A" "Y" "0") 2)
+                          (bag (tuple "A" "Y" "1") 2))
+      1)
+     (bag
+      (bag.union_disjoint (bag (tuple "B" "X" "0") 2)
+                          (bag (tuple "B" "X" "1") 2)
+                          (bag (tuple "B" "Y" "0") 2)
+                          (bag (tuple "B" "Y" "1") 2))
+      1))))
+
+; parition by second column
+(assert
+ (= ((_ table.group 1) truthTable)
+    (bag.union_disjoint
+     (bag
+      (bag.union_disjoint (bag (tuple "A" "X" "0") 2)
+                          (bag (tuple "A" "X" "1") 2)
+                          (bag (tuple "B" "X" "0") 2)
+                          (bag (tuple "B" "X" "1") 2))
+      1)
+     (bag
+      (bag.union_disjoint (bag (tuple "A" "Y" "0") 2)
+                          (bag (tuple "A" "Y" "1") 2)
+                          (bag (tuple "B" "Y" "0") 2)
+                          (bag (tuple "B" "Y" "1") 2))
+      1))))
+
+; parition by third column
+(assert
+ (= ((_ table.group 2) truthTable)
+    (bag.union_disjoint
+     (bag
+      (bag.union_disjoint (bag (tuple "A" "X" "0") 2)
+                          (bag (tuple "A" "Y" "0") 2)
+                          (bag (tuple "B" "X" "0") 2)
+                          (bag (tuple "B" "Y" "0") 2))
+      1)
+     (bag
+      (bag.union_disjoint (bag (tuple "A" "X" "1") 2)
+                          (bag (tuple "A" "Y" "1") 2)
+                          (bag (tuple "B" "X" "1") 2)
+                          (bag (tuple "B" "Y" "1") 2))
+      1))))
+
+; parition by first,second columns
+(assert
+ (= ((_ table.group 0 1) truthTable)
+    (bag.union_disjoint
+     (bag
+      (bag.union_disjoint (bag (tuple "A" "X" "0") 2)
+                          (bag (tuple "A" "X" "1") 2))
+      1)
+     (bag
+      (bag.union_disjoint (bag (tuple "A" "Y" "0") 2)
+                          (bag (tuple "A" "Y" "1") 2))
+      1)
+     (bag
+      (bag.union_disjoint (bag (tuple "B" "X" "0") 2)
+                          (bag (tuple "B" "X" "1") 2))
+      1)
+     (bag
+      (bag.union_disjoint (bag (tuple "B" "Y" "0") 2)
+                          (bag (tuple "B" "Y" "1") 2))
+      1))))
+
+; parition by no column
+(assert
+ (= (table.group truthTable)
+    (bag
+     (bag.union_disjoint
+      (bag (tuple "A" "X" "0") 2)
+      (bag (tuple "A" "X" "1") 2)
+      (bag (tuple "A" "Y" "0") 2)
+      (bag (tuple "A" "Y" "1") 2)
+      (bag (tuple "B" "X" "0") 2)
+      (bag (tuple "B" "X" "1") 2)
+      (bag (tuple "B" "Y" "0") 2)
+      (bag (tuple "B" "Y" "1") 2))
+     1)))
+
+; parition by all columns
+(assert
+ (= ((_ table.group 0 1 2) truthTable)
+    (bag.union_disjoint
+     (bag (bag (tuple "A" "X" "0") 2) 1)
+     (bag (bag (tuple "A" "X" "1") 2) 1)
+     (bag (bag (tuple "A" "Y" "0") 2) 1)
+     (bag (bag (tuple "A" "Y" "1") 2) 1)
+     (bag (bag (tuple "B" "X" "0") 2) 1)
+     (bag (bag (tuple "B" "X" "1") 2) 1)
+     (bag (bag (tuple "B" "Y" "0") 2) 1)
+     (bag (bag (tuple "B" "Y" "1") 2) 1))))
+
+(check-sat)