From b2e25ec6ffadc4bbc9e45962da384a8c192d042e Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Tue, 24 May 2022 09:51:49 -0500 Subject: [PATCH] Add table.group operator (#8731) --- src/api/cpp/cvc5.cpp | 6 + src/api/cpp/cvc5_kind.h | 68 ++++++++-- src/expr/skolem_manager.cpp | 2 +- src/expr/skolem_manager.h | 2 +- src/parser/smt2/Smt2.g | 14 +++ src/parser/smt2/smt2.cpp | 4 +- src/printer/smt2/smt2_printer.cpp | 16 +++ src/theory/bags/bag_reduction.cpp | 20 +-- src/theory/bags/bag_reduction.h | 11 +- src/theory/bags/bags_utils.cpp | 79 ++++++++++++ src/theory/bags/bags_utils.h | 8 ++ src/theory/bags/kinds | 12 ++ src/theory/bags/solver_state.cpp | 2 +- src/theory/bags/table_project_op.cpp | 5 + src/theory/bags/table_project_op.h | 14 ++- src/theory/bags/theory_bags.cpp | 1 + src/theory/bags/theory_bags_type_rules.cpp | 33 +++++ src/theory/bags/theory_bags_type_rules.h | 11 ++ src/theory/datatypes/tuple_utils.cpp | 17 +++ src/theory/datatypes/tuple_utils.h | 12 ++ test/regress/cli/CMakeLists.txt | 1 + .../cli/regress1/bags/table_group1.smt2 | 116 ++++++++++++++++++ 22 files changed, 417 insertions(+), 37 deletions(-) create mode 100644 test/regress/cli/regress1/bags/table_group1.smt2 diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 980e6468b..b6c8fc95d 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -334,6 +334,7 @@ const static std::unordered_map> KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT), KIND_ENUM(TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE), KIND_ENUM(TABLE_JOIN, internal::Kind::TABLE_JOIN), + KIND_ENUM(TABLE_GROUP, internal::Kind::TABLE_GROUP), /* Strings ---------------------------------------------------------- */ KIND_ENUM(STRING_CONCAT, internal::Kind::STRING_CONCAT), KIND_ENUM(STRING_IN_REGEXP, internal::Kind::STRING_IN_REGEXP), @@ -657,6 +658,8 @@ const static std::unordered_map& args) const case TABLE_JOIN: res = mkOpHelper(kind, internal::TableJoinOp(args)); break; + case TABLE_GROUP: + res = mkOpHelper(kind, internal::TableGroupOp(args)); + break; default: if (nargs == 0) { diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index 2e1fc435f..40797e027 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -3743,16 +3743,22 @@ enum Kind : int32_t * * - Create Op of this kind with: * - Solver::mkOp(Kind, const std::vector&) const + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst */ TABLE_PROJECT, /** + * \rst + * * Table aggregate operator has the form - * :math:`((\_ \; table.aggr \; n_1 ... n_k) f i A)` + * :math:`((\_ \; table.aggr \; n_1 ... n_k) \; f \; i \; A)` * where :math:`n_1, ..., n_k` are natural numbers, * :math:`f` is a function of type * :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)`, * :math:`i` has the type :math:`T`, - * and :math`A` has type :math:`Table T_1 ... T_j`. + * and :math:`A` has type :math:`(Table \; T_1 \; ... \; T_j)`. * The returned type is :math:`(Bag \; T)`. * * This operator aggregates elements in A that have the same tuple projection @@ -3760,43 +3766,89 @@ enum Kind : int32_t * and initial value :math:`i`. * * - Arity: ``3`` + * * - ``1:`` Term of sort :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)` * - ``2:`` Term of Sort :math:`T` * - ``3:`` Term of table sort :math:`Table T_1 ... T_j` * * - Indices: ``n`` * - ``1..n:`` Indices of the projection - * + * \endrst * - Create Term of this Kind with: * - Solver::mkTerm(const Op&, const std::vector&) const * * - Create Op of this kind with: * - Solver::mkOp(Kind, const std::vector&) const + * + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst */ TABLE_AGGREGATE, /** - * Table join operator has the form + * \rst + * Table join operator has the form * :math:`((\_ \; table.join \; m_1 \; n_1 \; \dots \; m_k \; n_k) \; A \; B)` - * where * :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers, - * and A, B are tables. + * where :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers, + * and :math:`A, B` are tables. * This operator filters the product of two bags based on the equality of - * projected tuples using indices :math:`m_1, \dots, m_k` in table A, - * and indices :math:`n_1, \dots, n_k` in table B. + * projected tuples using indices :math:`m_1, \dots, m_k` in table :math:`A`, + * and indices :math:`n_1, \dots, n_k` in table :math:`B`. * * - Arity: ``2`` + * * - ``1:`` Term of table Sort + * * - ``2:`` Term of table Sort * * - Indices: ``n`` * - ``1..n:`` Indices of the projection * + * \endrst * - Create Term of this Kind with: * - Solver::mkTerm(const Op&, const std::vector&) const * * - Create Op of this kind with: * - Solver::mkOp(Kind, const std::vector&) const + * + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst */ TABLE_JOIN, + /** + * Table group + * + * \rst + * :math:`((\_ \; table.group \; n_1 \; \dots \; n_k) \; A)` partitions tuples + * of table :math:`A` such that tuples that have the same projection + * with indices :math:`n_1 \; \dots \; n_k` are in the same part. + * It returns a bag of tables of type :math:`(Bag \; T)` where + * :math:`T` is the type of :math:`A`. + * + * - Arity: ``1`` + * + * - ``1:`` Term of table sort + * + * - Indices: ``n`` + * + * - ``1..n:`` Indices of the projection + * + * \endrst + * + * - Create Term of this Kind with: + * + * - Solver::mkTerm(Kind, const std::vector&) const + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst + */ + TABLE_GROUP, /* Strings --------------------------------------------------------------- */ /** diff --git a/src/expr/skolem_manager.cpp b/src/expr/skolem_manager.cpp index 993d9b4d8..58276262b 100644 --- a/src/expr/skolem_manager.cpp +++ b/src/expr/skolem_manager.cpp @@ -94,7 +94,7 @@ const char* toString(SkolemFunId id) case SkolemFunId::BAGS_MAP_PREIMAGE_SIZE: return "BAGS_MAP_PREIMAGE_SIZE"; case SkolemFunId::BAGS_MAP_PREIMAGE_INDEX: return "BAGS_MAP_PREIMAGE_INDEX"; case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM"; - case SkolemFunId::BAG_DEQ_DIFF: return "BAG_DEQ_DIFF"; + case SkolemFunId::BAGS_DEQ_DIFF: return "BAGS_DEQ_DIFF"; case SkolemFunId::SETS_CHOOSE: return "SETS_CHOOSE"; case SkolemFunId::SETS_DEQ_DIFF: return "SETS_DEQ_DIFF"; case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED"; diff --git a/src/expr/skolem_manager.h b/src/expr/skolem_manager.h index 2480d1d0c..9d4f225f9 100644 --- a/src/expr/skolem_manager.h +++ b/src/expr/skolem_manager.h @@ -168,7 +168,7 @@ enum class SkolemFunId */ BAGS_MAP_SUM, /** bag diff to witness (not (= A B)) */ - BAG_DEQ_DIFF, + BAGS_DEQ_DIFF, /** An interpreted function for bag.choose operator: * (choose A) is expanded as * (witness ((x elementType)) diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 239d36468..ee0738b11 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1425,6 +1425,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2] cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_JOIN, indices); expr = SOLVER->mkTerm(op, {expr}); } + | LPAREN_TOK TABLE_GROUP_TOK term[expr,expr2] RPAREN_TOK + { + std::vector indices; + cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_GROUP, indices); + expr = SOLVER->mkTerm(op, {expr}); + } | /* an atomic term (a term with no subterms) */ termAtomic[atomTerm] { expr = atomTerm; } ; @@ -1587,6 +1593,13 @@ identifier[cvc5::ParseOp& p] p.d_kind = cvc5::TABLE_JOIN; p.d_op = SOLVER->mkOp(cvc5::TABLE_JOIN, numerals); } + | TABLE_GROUP_TOK nonemptyNumeralList[numerals] + { + // we adopt a special syntax (_ table.group i_1 ... i_n) where + // i_1, ..., j_n are numerals + p.d_kind = cvc5::TABLE_GROUP; + p.d_op = SOLVER->mkOp(cvc5::TABLE_GROUP, numerals); + } | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals] { cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName); @@ -2198,6 +2211,7 @@ TUPLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATA TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.project'; TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.aggr'; TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join'; +TABLE_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.group'; FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card'; HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->'; diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index cf28d0ac8..66f2db214 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -636,6 +636,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(cvc5::BAG_FOLD, "bag.fold"); addOperator(cvc5::BAG_PARTITION, "bag.partition"); addOperator(cvc5::TABLE_PRODUCT, "table.product"); + addOperator(cvc5::BAG_PARTITION, "table.group"); } if (d_logic.isTheoryEnabled(internal::theory::THEORY_STRINGS)) { @@ -1126,7 +1127,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) return ret; } else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT - || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN) + || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN + || p.d_kind == cvc5::TABLE_GROUP) { cvc5::Term ret = d_solver->mkTerm(p.d_op, args); Trace("parser") << "applyParseOp: return projection " << ret << std::endl; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index ec29595e7..279046300 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -792,6 +792,21 @@ void Smt2Printer::toStream(std::ostream& out, } return; } + case kind::TABLE_GROUP: + { + TableGroupOp op = n.getOperator().getConst(); + if (op.getIndices().empty()) + { + // e.g. (table.group A) + out << "table.group " << n[0] << ")"; + } + else + { + // e.g. ((_ table.group 0 1 2 3) A) + out << "(_ table.group" << op << ") " << n[0] << ")"; + } + return; + } case kind::CONSTRUCTOR_TYPE: { out << n[n.getNumChildren()-1]; @@ -1155,6 +1170,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::TABLE_PROJECT: return "table.project"; case kind::TABLE_AGGREGATE: return "table.aggr"; case kind::TABLE_JOIN: return "table.join"; + case kind::TABLE_GROUP: return "table.group"; // fp theory case kind::FLOATINGPOINT_FP: return "fp"; diff --git a/src/theory/bags/bag_reduction.cpp b/src/theory/bags/bag_reduction.cpp index 89a422301..e7ccba325 100644 --- a/src/theory/bags/bag_reduction.cpp +++ b/src/theory/bags/bag_reduction.cpp @@ -214,28 +214,16 @@ Node BagReduction::reduceAggregateOperator(Node node) const std::vector& indices = node.getOperator().getConst().getIndices(); - Node t1 = bvm->mkBoundVar(node, "t1", elementType); - Node t2 = bvm->mkBoundVar(node, "t2", elementType); - Node list = nm->mkNode(BOUND_VAR_LIST, t1, t2); - Node body = nm->mkConst(true); - for (uint32_t i : indices) - { - Node select1 = datatypes::TupleUtils::nthElementOfTuple(t1, i); - Node select2 = datatypes::TupleUtils::nthElementOfTuple(t2, i); - Node equal = select1.eqNode(select2); - body = body.andNode(equal); - } - - Node lambda = nm->mkNode(LAMBDA, list, body); - Node partition = nm->mkNode(BAG_PARTITION, lambda, A); + Node groupOp = nm->mkConst(TableGroupOp(indices)); + Node group = nm->mkNode(TABLE_GROUP, {groupOp, A}); Node bag = bvm->mkBoundVar( - partition, "bag", nm->mkBagType(elementType)); + group, "bag", nm->mkBagType(elementType)); Node foldList = nm->mkNode(BOUND_VAR_LIST, bag); Node foldBody = nm->mkNode(BAG_FOLD, function, initialValue, bag); Node fold = nm->mkNode(LAMBDA, foldList, foldBody); - Node map = nm->mkNode(BAG_MAP, fold, partition); + Node map = nm->mkNode(BAG_MAP, fold, group); return map; } diff --git a/src/theory/bags/bag_reduction.h b/src/theory/bags/bag_reduction.h index c3f49b0a4..cf391a120 100644 --- a/src/theory/bags/bag_reduction.h +++ b/src/theory/bags/bag_reduction.h @@ -98,16 +98,11 @@ class BagReduction static Node reduceCardOperator(Node node, std::vector& asserts); /** * @param node of the form ((_ table.aggr n1 ... nk) f initial A)) - * @return reduction term that uses map, fold, and partition using - * tuple projection as the equivalence relation as follows: + * @return reduction term that uses map, fold, and group operators + * as follows: * (bag.map * (lambda ((B Table)) (bag.fold f initial B)) - * (bag.partition - * (lambda ((t1 Tuple) (t2 Tuple)) ; equivalence relation - * (= - * ((_ tuple.project n1 ... nk) t1) - * ((_ tuple.project n1 ... nk) t2))) - * A)) + * ((_ table.group n1 ... nk) A)) */ static Node reduceAggregateOperator(Node node); }; diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp index 0987bccfc..4935a24d4 100644 --- a/src/theory/bags/bags_utils.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -146,6 +146,7 @@ Node BagsUtils::evaluate(Rewriter* rewriter, TNode n) case BAG_FOLD: return evaluateBagFold(n); case TABLE_PRODUCT: return evaluateProduct(n); case TABLE_JOIN: return evaluateJoin(rewriter, n); + case TABLE_GROUP: return evaluateGroup(rewriter, n); case TABLE_PROJECT: return evaluateTableProject(n); default: break; } @@ -974,6 +975,84 @@ Node BagsUtils::evaluateJoin(Rewriter* rewriter, TNode n) return ret; } +Node BagsUtils::evaluateGroup(Rewriter* rewriter, TNode n) +{ + Assert(n.getKind() == TABLE_GROUP); + + NodeManager* nm = NodeManager::currentNM(); + + Node A = n[0]; + TypeNode bagType = A.getType(); + TypeNode partitionType = n.getType(); + + if (A.getKind() == BAG_EMPTY) + { + // return a nonempty partition + return nm->mkNode(BAG_MAKE, A, nm->mkConstInt(Rational(1))); + } + + std::vector indices = + n.getOperator().getConst().getIndices(); + + std::map elements = BagsUtils::getBagElements(A); + Trace("bags-group") << "elements: " << elements << std::endl; + // a simple map from elements to equivalent classes with this invariant: + // each key element must appear exactly once in one of the values. + std::map> sets; + std::set emptyClass; + for (const auto& pair : elements) + { + // initially each singleton element is an equivalence class + sets[pair.first] = {pair.first}; + } + for (std::map::iterator i = elements.begin(); + i != elements.end(); + ++i) + { + if (sets[i->first].empty()) + { + // skip this element since its equivalent class has already been processed + continue; + } + std::map::iterator j = i; + ++j; + while (j != elements.end()) + { + if (TupleUtils::sameProjection(indices, i->first, j->first)) + { + // add element j to the equivalent class + sets[i->first].insert(j->first); + // mark the equivalent class of j as processed + sets[j->first] = emptyClass; + } + ++j; + } + } + + // construct the partition parts + std::map parts; + for (std::pair> pair : sets) + { + const std::set& eqc = pair.second; + if (eqc.empty()) + { + continue; + } + std::vector bags; + for (const Node& node : eqc) + { + Node bag = nm->mkNode(BAG_MAKE, node, nm->mkConstInt(elements[node])); + bags.push_back(bag); + } + Node part = computeDisjointUnion(bagType, bags); + // each part in the partitions has multiplicity one + parts[part] = Rational(1); + } + Node ret = constructConstantBagFromElements(partitionType, parts); + Trace("bags-partition") << "ret: " << ret << std::endl; + return ret; +} + Node BagsUtils::evaluateTableProject(TNode n) { Assert(n.getKind() == TABLE_PROJECT); diff --git a/src/theory/bags/bags_utils.h b/src/theory/bags/bags_utils.h index 23f21371b..4da592d5a 100644 --- a/src/theory/bags/bags_utils.h +++ b/src/theory/bags/bags_utils.h @@ -132,6 +132,14 @@ class BagsUtils */ static Node evaluateJoin(Rewriter* rewriter, TNode n); + /** + * @param n of the form ((_ table.group (n_1 ... n_k) ) A) where A is a + * constant table + * @return a partition of A such that each part contains tuples with the same + * projection with indices n_1 ... n_k + */ + static Node evaluateGroup(Rewriter* rewriter, TNode n); + /** * @param n of the form ((_ table.project i_1 ... i_n) A) where A is a * constant diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index d3a98a311..2522585bd 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -144,6 +144,16 @@ constant TABLE_JOIN_OP \ parameterized TABLE_JOIN TABLE_JOIN_OP 2 "table join" +# table.group operator +constant TABLE_GROUP_OP \ + class \ + TableGroupOp \ + ::cvc5::internal::TableGroupOpHashFunction \ + "theory/bags/table_project_op.h" \ + "operator for TABLE_GROUP; payload is an instance of the cvc5::internal::TableGroupOp class" + +parameterized TABLE_GROUP TABLE_GROUP_OP 1 "table group" + typerule TABLE_PRODUCT ::cvc5::internal::theory::bags::TableProductTypeRule typerule TABLE_PROJECT_OP "SimpleTypeRule" typerule TABLE_PROJECT ::cvc5::internal::theory::bags::TableProjectTypeRule @@ -151,5 +161,7 @@ typerule TABLE_AGGREGATE_OP "SimpleTypeRule" typerule TABLE_AGGREGATE ::cvc5::internal::theory::bags::TableAggregateTypeRule typerule TABLE_JOIN_OP "SimpleTypeRule" typerule TABLE_JOIN ::cvc5::internal::theory::bags::TableJoinTypeRule +typerule TABLE_GROUP_OP "SimpleTypeRule" +typerule TABLE_GROUP ::cvc5::internal::theory::bags::TableGroupTypeRule endtheory diff --git a/src/theory/bags/solver_state.cpp b/src/theory/bags/solver_state.cpp index f57fc1206..5fe8bae13 100644 --- a/src/theory/bags/solver_state.cpp +++ b/src/theory/bags/solver_state.cpp @@ -119,7 +119,7 @@ void SolverState::collectDisequalBagTerms() TypeNode elementType = A.getType().getBagElementType(); SkolemManager* sm = d_nm->getSkolemManager(); Node skolem = sm->mkSkolemFunction( - SkolemFunId::BAG_DEQ_DIFF, elementType, {A, B}); + SkolemFunId::BAGS_DEQ_DIFF, elementType, {A, B}); d_deq[equal] = skolem; } } diff --git a/src/theory/bags/table_project_op.cpp b/src/theory/bags/table_project_op.cpp index 72700be9d..9c19cdba7 100644 --- a/src/theory/bags/table_project_op.cpp +++ b/src/theory/bags/table_project_op.cpp @@ -32,4 +32,9 @@ TableJoinOp::TableJoinOp(std::vector indices) { } +TableGroupOp::TableGroupOp(std::vector indices) + : ProjectOp(std::move(indices)) +{ +} + } // namespace cvc5::internal diff --git a/src/theory/bags/table_project_op.h b/src/theory/bags/table_project_op.h index 10c45f915..03f2f5561 100644 --- a/src/theory/bags/table_project_op.h +++ b/src/theory/bags/table_project_op.h @@ -54,7 +54,6 @@ struct TableAggregateOpHashFunction : public ProjectOpHashFunction { }; /* struct TableAggregateOpHashFunction */ - class TableJoinOp : public ProjectOp { public: @@ -69,6 +68,19 @@ struct TableJoinOpHashFunction : public ProjectOpHashFunction { }; /* struct TableJoinOpHashFunction */ +class TableGroupOp : public ProjectOp +{ + public: + explicit TableGroupOp(std::vector indices); + TableGroupOp(const TableGroupOp& op) = default; +}; /* class TableGroupOp */ + +/** + * Hash function for the TableGroupOpHashFunction objects. + */ +struct TableGroupOpHashFunction : public ProjectOpHashFunction +{ +}; /* struct TableGroupOpHashFunction */ } // namespace cvc5::internal diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index c1d34eac8..1581a091b 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -84,6 +84,7 @@ void TheoryBags::finishInit() d_equalityEngine->addFunctionKind(TABLE_PROJECT); d_equalityEngine->addFunctionKind(TABLE_AGGREGATE); d_equalityEngine->addFunctionKind(TABLE_JOIN); + d_equalityEngine->addFunctionKind(TABLE_GROUP); } TrustNode TheoryBags::ppRewrite(TNode atom, std::vector& lems) diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index ee3d9b5a8..0bcf95612 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -707,6 +707,39 @@ TypeNode TableJoinTypeRule::computeType(NodeManager* nm, TNode n, bool check) return nm->mkBagType(retTupleType); } +TypeNode TableGroupTypeRule::computeType(NodeManager* nm, TNode n, bool check) +{ + Assert(n.getKind() == kind::TABLE_GROUP && n.hasOperator() + && n.getOperator().getKind() == kind::TABLE_GROUP_OP); + TableGroupOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + + TypeNode bagType = n[0].getType(check); + + if (check) + { + if (!bagType.isBag()) + { + std::stringstream ss; + ss << "TABLE_GROUP operator expects a table. Found '" << n[0] + << "' of type '" << bagType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode tupleType = bagType.getBagElementType(); + if (!tupleType.isTuple()) + { + std::stringstream ss; + ss << "TABLE_GROUP operator expects a table. Found '" << n[0] + << "' of type '" << bagType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TupleUtils::checkTypeIndices(n, tupleType, indices); + } + return nm->mkBagType(bagType); +} + Cardinality BagsProperties::computeCardinality(TypeNode type) { return Cardinality::INTEGERS; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 445d627db..86fa42822 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -216,6 +216,17 @@ struct TableJoinTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct TableJoinTypeRule */ +/** + * Table group operator is indexed by a list of indices (n_1, ..., n_k). It + * ensures that the argument is a table whose arity is greater than each n_i for + * i = 1, ..., k. If the passed table is of type T, then the returned type is + * (Bag T), i.e., bag of tables. + */ +struct TableGroupTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct TableGroupTypeRule */ + struct BagsProperties { static Cardinality computeCardinality(TypeNode type); diff --git a/src/theory/datatypes/tuple_utils.cpp b/src/theory/datatypes/tuple_utils.cpp index 838e840c3..74024f508 100644 --- a/src/theory/datatypes/tuple_utils.cpp +++ b/src/theory/datatypes/tuple_utils.cpp @@ -145,6 +145,23 @@ std::vector TupleUtils::getTupleElements(Node tuple1, Node tuple2) return elements; } +bool TupleUtils::sameProjection(const std::vector& indices, + Node tuple1, + Node tuple2) +{ + Assert(tuple1.isConst() && tuple2.isConst()) + << "Both " << tuple1 << " and " << tuple2 << " are not constants" + << std::endl; + for (uint32_t index : indices) + { + if (tuple1[index] != tuple2[index]) + { + return false; + } + } + return true; +} + Node TupleUtils::constructTupleFromElements(TypeNode tupleType, const std::vector& elements, size_t start, diff --git a/src/theory/datatypes/tuple_utils.h b/src/theory/datatypes/tuple_utils.h index 9afbd59fe..a7f76ccd2 100644 --- a/src/theory/datatypes/tuple_utils.h +++ b/src/theory/datatypes/tuple_utils.h @@ -81,6 +81,18 @@ class TupleUtils */ static std::vector getTupleElements(Node tuple1, Node tuple2); + /** + * @param indices a list of indices for projected elements n_1, ..., n_k + * @param tuple1 a constant tuple node + * @param tuple2 a constant tuple node + * @return a boolean representing the equality of + * ((_ tuple.projection n_1 ... n_k) tuple1) and + * ((_ tuple.projection n_1 ... n_k) tuple2). + */ + static bool sameProjection(const std::vector& indices, + Node tuple1, + Node tuple2); + /** * construct a tuple from a list of elements * @param tupleType the type of the returned tuple diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index a9d2f4f6e..82e07c46e 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -1828,6 +1828,7 @@ set(regress_1_tests regress1/bags/subbag1.smt2 regress1/bags/subbag2.smt2 regress1/bags/table_aggregate1.smt2 + regress1/bags/table_group1.smt2 regress1/bags/table_join1.smt2 regress1/bags/table_join2.smt2 regress1/bags/table_join3.smt2 diff --git a/test/regress/cli/regress1/bags/table_group1.smt2 b/test/regress/cli/regress1/bags/table_group1.smt2 new file mode 100644 index 000000000..5bb0b443a --- /dev/null +++ b/test/regress/cli/regress1/bags/table_group1.smt2 @@ -0,0 +1,116 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(define-fun truthTable () (Table String String String) + (bag.union_disjoint + (bag (tuple "A" "X" "0") 2) + (bag (tuple "A" "X" "1") 2) + (bag (tuple "A" "Y" "0") 2) + (bag (tuple "A" "Y" "1") 2) + (bag (tuple "B" "X" "0") 2) + (bag (tuple "B" "X" "1") 2) + (bag (tuple "B" "Y" "0") 2) + (bag (tuple "B" "Y" "1") 2))) + +; parition by first column +(assert + (= ((_ table.group 0) truthTable) + (bag.union_disjoint + (bag + (bag.union_disjoint (bag (tuple "A" "X" "0") 2) + (bag (tuple "A" "X" "1") 2) + (bag (tuple "A" "Y" "0") 2) + (bag (tuple "A" "Y" "1") 2)) + 1) + (bag + (bag.union_disjoint (bag (tuple "B" "X" "0") 2) + (bag (tuple "B" "X" "1") 2) + (bag (tuple "B" "Y" "0") 2) + (bag (tuple "B" "Y" "1") 2)) + 1)))) + +; parition by second column +(assert + (= ((_ table.group 1) truthTable) + (bag.union_disjoint + (bag + (bag.union_disjoint (bag (tuple "A" "X" "0") 2) + (bag (tuple "A" "X" "1") 2) + (bag (tuple "B" "X" "0") 2) + (bag (tuple "B" "X" "1") 2)) + 1) + (bag + (bag.union_disjoint (bag (tuple "A" "Y" "0") 2) + (bag (tuple "A" "Y" "1") 2) + (bag (tuple "B" "Y" "0") 2) + (bag (tuple "B" "Y" "1") 2)) + 1)))) + +; parition by third column +(assert + (= ((_ table.group 2) truthTable) + (bag.union_disjoint + (bag + (bag.union_disjoint (bag (tuple "A" "X" "0") 2) + (bag (tuple "A" "Y" "0") 2) + (bag (tuple "B" "X" "0") 2) + (bag (tuple "B" "Y" "0") 2)) + 1) + (bag + (bag.union_disjoint (bag (tuple "A" "X" "1") 2) + (bag (tuple "A" "Y" "1") 2) + (bag (tuple "B" "X" "1") 2) + (bag (tuple "B" "Y" "1") 2)) + 1)))) + +; parition by first,second columns +(assert + (= ((_ table.group 0 1) truthTable) + (bag.union_disjoint + (bag + (bag.union_disjoint (bag (tuple "A" "X" "0") 2) + (bag (tuple "A" "X" "1") 2)) + 1) + (bag + (bag.union_disjoint (bag (tuple "A" "Y" "0") 2) + (bag (tuple "A" "Y" "1") 2)) + 1) + (bag + (bag.union_disjoint (bag (tuple "B" "X" "0") 2) + (bag (tuple "B" "X" "1") 2)) + 1) + (bag + (bag.union_disjoint (bag (tuple "B" "Y" "0") 2) + (bag (tuple "B" "Y" "1") 2)) + 1)))) + +; parition by no column +(assert + (= (table.group truthTable) + (bag + (bag.union_disjoint + (bag (tuple "A" "X" "0") 2) + (bag (tuple "A" "X" "1") 2) + (bag (tuple "A" "Y" "0") 2) + (bag (tuple "A" "Y" "1") 2) + (bag (tuple "B" "X" "0") 2) + (bag (tuple "B" "X" "1") 2) + (bag (tuple "B" "Y" "0") 2) + (bag (tuple "B" "Y" "1") 2)) + 1))) + +; parition by all columns +(assert + (= ((_ table.group 0 1 2) truthTable) + (bag.union_disjoint + (bag (bag (tuple "A" "X" "0") 2) 1) + (bag (bag (tuple "A" "X" "1") 2) 1) + (bag (bag (tuple "A" "Y" "0") 2) 1) + (bag (bag (tuple "A" "Y" "1") 2) 1) + (bag (bag (tuple "B" "X" "0") 2) 1) + (bag (bag (tuple "B" "X" "1") 2) 1) + (bag (bag (tuple "B" "Y" "0") 2) 1) + (bag (bag (tuple "B" "Y" "1") 2) 1)))) + +(check-sat) -- 2.30.2