From: mudathirmahgoub Date: Thu, 30 Jun 2022 02:03:36 +0000 (-0500) Subject: Add set.aggr operator to sets (#8878) X-Git-Tag: cvc5-1.0.1~26 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=9e20f45b9dcbc3bb2b0ddfd96337454a3221baf2;p=cvc5.git Add set.aggr operator to sets (#8878) This PR depends on #8876 --- diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 1c45d2d4d..692eff021 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -310,6 +310,7 @@ const static std::unordered_map> KIND_ENUM(RELATION_JOIN_IMAGE, internal::Kind::RELATION_JOIN_IMAGE), KIND_ENUM(RELATION_IDEN, internal::Kind::RELATION_IDEN), KIND_ENUM(RELATION_GROUP, internal::Kind::RELATION_GROUP), + KIND_ENUM(RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE), /* Bags ------------------------------------------------------------- */ KIND_ENUM(BAG_UNION_MAX, internal::Kind::BAG_UNION_MAX), KIND_ENUM(BAG_UNION_DISJOINT, internal::Kind::BAG_UNION_DISJOINT), @@ -635,6 +636,8 @@ const static std::unordered_map s_op_kinds{ {REGEXP_REPEAT, internal::Kind::REGEXP_REPEAT_OP}, {REGEXP_LOOP, internal::Kind::REGEXP_LOOP_OP}, {TUPLE_PROJECT, internal::Kind::TUPLE_PROJECT_OP}, + {RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE_OP}, {RELATION_GROUP, internal::Kind::RELATION_GROUP_OP}, {TABLE_PROJECT, internal::Kind::TABLE_PROJECT_OP}, {TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE_OP}, @@ -1953,6 +1957,7 @@ size_t Op::getNumIndicesHelper() const case FLOATINGPOINT_TO_FP_FROM_UBV: size = 2; break; case REGEXP_LOOP: size = 2; break; case TUPLE_PROJECT: + case RELATION_AGGREGATE: case RELATION_GROUP: case TABLE_AGGREGATE: case TABLE_GROUP: @@ -2114,6 +2119,7 @@ Term Op::getIndexHelper(size_t index) const break; } case TUPLE_PROJECT: + case RELATION_AGGREGATE: case RELATION_GROUP: case TABLE_AGGREGATE: case TABLE_GROUP: @@ -6156,6 +6162,7 @@ Op Solver::mkOp(Kind kind, const std::vector& args) const res = mkOpHelper(kind, internal::RegExpLoop(args[0], args[1])); break; case TUPLE_PROJECT: + case RELATION_AGGREGATE: case RELATION_GROUP: case TABLE_AGGREGATE: case TABLE_GROUP: diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index f71d6a447..4d4234f03 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -3361,6 +3361,43 @@ enum Kind : int32_t * \endrst */ RELATION_GROUP, + /** + * \rst + * + * Relation aggregate operator has the form + * :math:`((\_ \; rel.aggr \; n_1 ... n_k) \; f \; i \; A)` + * where :math:`n_1, ..., n_k` are natural numbers, + * :math:`f` is a function of type + * :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)`, + * :math:`i` has the type :math:`T`, + * and :math:`A` has type :math:`(Relation \; T_1 \; ... \; T_j)`. + * The returned type is :math:`(Set \; T)`. + * + * This operator aggregates elements in A that have the same tuple projection + * with indices n_1, ..., n_k using the combining function :math:`f`, + * and initial value :math:`i`. + * + * - Arity: ``3`` + * + * - ``1:`` Term of sort :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)` + * - ``2:`` Term of Sort :math:`T` + * - ``3:`` Term of relation sort :math:`Relation T_1 ... T_j` + * + * - Indices: ``n`` + * - ``1..n:`` Indices of the projection + * \endrst + * - Create Term of this Kind with: + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * - Create Op of this kind with: + * - Solver::mkOp(Kind, const std::vector&) const + * + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst + */ + RELATION_AGGREGATE, /* Bags ------------------------------------------------------------------ */ diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 875e41ee4..96421e6ce 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1435,6 +1435,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2] cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_GROUP, indices); expr = SOLVER->mkTerm(op, {expr}); } + | LPAREN_TOK RELATION_AGGREGATE_TOK term[expr,expr2] RPAREN_TOK + { + std::vector indices; + cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, indices); + expr = SOLVER->mkTerm(op, {expr}); + } | /* an atomic term (a term with no subterms) */ termAtomic[atomTerm] { expr = atomTerm; } ; @@ -1611,6 +1617,13 @@ identifier[cvc5::ParseOp& p] p.d_kind = cvc5::RELATION_GROUP; p.d_op = SOLVER->mkOp(cvc5::RELATION_GROUP, numerals); } + | RELATION_AGGREGATE_TOK nonemptyNumeralList[numerals] + { + // we adopt a special syntax (_ rel.aggr i_1 ... i_n) where + // i_1, ..., i_n are numerals + p.d_kind = cvc5::RELATION_AGGREGATE; + p.d_op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, numerals); + } | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals] { cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName); @@ -2224,6 +2237,7 @@ TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BA TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join'; TABLE_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.group'; RELATION_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.group'; +RELATION_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.aggr'; FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card'; HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->'; diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 900ff1016..5f4ae42e4 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -1130,7 +1130,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) } else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN - || p.d_kind == cvc5::TABLE_GROUP || p.d_kind == cvc5::RELATION_GROUP) + || p.d_kind == cvc5::TABLE_GROUP || p.d_kind == cvc5::RELATION_GROUP + || p.d_kind == cvc5::RELATION_AGGREGATE) { cvc5::Term ret = d_solver->mkTerm(p.d_op, args); Trace("parser") << "applyParseOp: return projection " << ret << std::endl; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 5dc9252a9..3ffe7641c 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -779,7 +779,7 @@ void Smt2Printer::toStream(std::ostream& out, ProjectOp op = n.getOperator().getConst(); if (op.getIndices().empty()) { - // e.g. (table.project function initial_value bag) + // e.g. (table.aggr function initial_value bag) out << "table.aggr " << n[0] << " " << n[1] << " " << n[2] << ")"; } else @@ -835,6 +835,22 @@ void Smt2Printer::toStream(std::ostream& out, } return; } + case kind::RELATION_AGGREGATE: + { + ProjectOp op = n.getOperator().getConst(); + if (op.getIndices().empty()) + { + // e.g. (rel.aggr function initial_value bag) + out << "rel.aggr " << n[0] << " " << n[1] << " " << n[2] << ")"; + } + else + { + // e.g. ((_ rel.aggr 0) function initial_value bag) + out << "(_ rel.aggr" << op << ") " << n[0] << " " << n[1] << " " << n[2] + << ")"; + } + return; + } case kind::CONSTRUCTOR_TYPE: { out << n[n.getNumChildren()-1]; @@ -1175,6 +1191,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::RELATION_IDEN: return "rel.iden"; case kind::RELATION_JOIN_IMAGE: return "rel.join_image"; case kind::RELATION_GROUP: return "rel.group"; + case kind::RELATION_AGGREGATE: return "rel.aggr"; // bag theory case kind::BAG_TYPE: return "Bag"; diff --git a/src/theory/sets/kinds b/src/theory/sets/kinds index 0c0e17dc1..4c5899e96 100644 --- a/src/theory/sets/kinds +++ b/src/theory/sets/kinds @@ -95,10 +95,20 @@ constant RELATION_GROUP_OP \ ProjectOp+ \ ::cvc5::internal::ProjectOpHashFunction \ "theory/datatypes/project_op.h" \ - "operator for RELATION_GROUP; payload is an instance of the cvc5::internal::RelationGroupOp class" + "operator for RELATION_GROUP; payload is an instance of the cvc5::internal::ProjectOp class" parameterized RELATION_GROUP RELATION_GROUP_OP 1 "relation group" +# relation aggregate operator +constant RELATION_AGGREGATE_OP \ + class \ + ProjectOp+ \ + ::cvc5::internal::ProjectOpHashFunction \ + "theory/datatypes/project_op.h" \ + "operator for RELATION_AGGREGATE; payload is an instance of the cvc5::internal::ProjectOp class" + +parameterized RELATION_AGGREGATE RELATION_AGGREGATE_OP 3 "relation aggregate" + operator RELATION_JOIN 2 "relation join" operator RELATION_PRODUCT 2 "relation cartesian product" operator RELATION_TRANSPOSE 1 "relation transpose" @@ -132,6 +142,8 @@ typerule RELATION_JOIN_IMAGE ::cvc5::internal::theory::sets::JoinImageTypeRu typerule RELATION_IDEN ::cvc5::internal::theory::sets::RelIdenTypeRule typerule RELATION_GROUP_OP "SimpleTypeRule" typerule RELATION_GROUP ::cvc5::internal::theory::sets::RelationGroupTypeRule +typerule RELATION_AGGREGATE_OP "SimpleTypeRule" +typerule RELATION_AGGREGATE ::cvc5::internal::theory::sets::RelationAggregateTypeRule construle SET_UNION ::cvc5::internal::theory::sets::SetsBinaryOperatorTypeRule construle SET_SINGLETON ::cvc5::internal::theory::sets::SingletonTypeRule diff --git a/src/theory/sets/rels_utils.cpp b/src/theory/sets/rels_utils.cpp index 08d4feb36..6f6c013dc 100644 --- a/src/theory/sets/rels_utils.cpp +++ b/src/theory/sets/rels_utils.cpp @@ -20,6 +20,7 @@ #include "theory/datatypes/project_op.h" #include "theory/datatypes/tuple_utils.h" #include "theory/sets/normal_form.h" +#include "theory/sets/set_reduction.h" using namespace cvc5::internal::kind; using namespace cvc5::internal::theory::datatypes; @@ -151,6 +152,19 @@ Node RelsUtils::evaluateGroup(TNode n) return ret; } +Node RelsUtils::evaluateRelationAggregate(TNode n) +{ + Assert(n.getKind() == RELATION_AGGREGATE); + if (!(n[1].isConst() && n[2].isConst())) + { + // we can't proceed further. + return n; + } + + Node reduction = SetReduction::reduceAggregateOperator(n); + return reduction; +} + } // namespace sets } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/sets/rels_utils.h b/src/theory/sets/rels_utils.h index 559ef5281..3df5a2879 100644 --- a/src/theory/sets/rels_utils.h +++ b/src/theory/sets/rels_utils.h @@ -72,6 +72,13 @@ class RelsUtils * projection with indices n_1 ... n_k */ static Node evaluateGroup(TNode n); + + /** + * @param n has the form ((_ rel.aggr n1 ... n_k) f initial A) + * where initial and A are constants + * @return the aggregation result. + */ + static Node evaluateRelationAggregate(TNode n); }; } // namespace sets } // namespace theory diff --git a/src/theory/sets/set_reduction.cpp b/src/theory/sets/set_reduction.cpp index 78b5b8036..f8c4ed4ae 100644 --- a/src/theory/sets/set_reduction.cpp +++ b/src/theory/sets/set_reduction.cpp @@ -18,7 +18,7 @@ #include "expr/bound_var_manager.h" #include "expr/emptyset.h" #include "expr/skolem_manager.h" -#include "theory/datatypes/tuple_utils.h" +#include "theory/datatypes//project_op.h" #include "theory/quantifiers/fmf/bounded_integers.h" #include "util/rational.h" @@ -120,6 +120,30 @@ Node SetReduction::reduceFoldOperator(Node node, std::vector& asserts) return combine_n; } +Node SetReduction::reduceAggregateOperator(Node node) +{ + Assert(node.getKind() == RELATION_AGGREGATE); + NodeManager* nm = NodeManager::currentNM(); + BoundVarManager* bvm = nm->getBoundVarManager(); + Node function = node[0]; + TypeNode elementType = function.getType().getArgTypes()[0]; + Node initialValue = node[1]; + Node A = node[2]; + + ProjectOp op = node.getOperator().getConst(); + Node groupOp = nm->mkConst(RELATION_GROUP_OP, op); + Node group = nm->mkNode(RELATION_GROUP, {groupOp, A}); + + Node set = bvm->mkBoundVar( + group, "set", nm->mkSetType(elementType)); + Node foldList = nm->mkNode(BOUND_VAR_LIST, set); + Node foldBody = nm->mkNode(SET_FOLD, function, initialValue, set); + + Node fold = nm->mkNode(LAMBDA, foldList, foldBody); + Node map = nm->mkNode(SET_MAP, fold, group); + return map; +} + } // namespace sets } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/sets/set_reduction.h b/src/theory/sets/set_reduction.h index 43172012b..3b0c4526e 100644 --- a/src/theory/sets/set_reduction.h +++ b/src/theory/sets/set_reduction.h @@ -64,6 +64,16 @@ class SetReduction * unionFn: Int -> (Set T1) is an uninterpreted function */ static Node reduceFoldOperator(Node node, std::vector& asserts); + + /** + * @param node of the form ((_ rel.aggr n1 ... nk) f initial A)) + * @return reduction term that uses map, fold, and group operators + * as follows: + * (set.map + * (lambda ((B Table)) (set.fold f initial B)) + * ((_ rel.group n1 ... nk) A)) + */ + static Node reduceAggregateOperator(Node node); }; } // namespace sets diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index 820f33e3b..b09080186 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -168,6 +168,11 @@ TrustNode TheorySets::ppRewrite(TNode n, std::vector& lems) d_im.lemma(andNode, InferenceId::BAGS_FOLD); return TrustNode::mkTrustRewrite(n, ret, nullptr); } + if (nk == TABLE_AGGREGATE) + { + Node ret = SetReduction::reduceAggregateOperator(n); + return TrustNode::mkTrustRewrite(ret, ret, nullptr); + } return d_internal->ppRewrite(n, lems); } diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index 5fcd15dd1..e9b3b31b0 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -590,6 +590,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { } case RELATION_GROUP: return postRewriteGroup(node); + case RELATION_AGGREGATE: return postRewriteAggregate(node); default: break; } @@ -765,6 +766,21 @@ RewriteResponse TheorySetsRewriter::postRewriteGroup(TNode n) return RewriteResponse(REWRITE_DONE, n); } +RewriteResponse TheorySetsRewriter::postRewriteAggregate(TNode n) +{ + Assert(n.getKind() == kind::RELATION_AGGREGATE); + if (n[1].isConst() && n[2].isConst()) + { + Node ret = RelsUtils::evaluateRelationAggregate(n); + if (ret != n) + { + return RewriteResponse(REWRITE_AGAIN_FULL, ret); + } + } + + return RewriteResponse(REWRITE_DONE, n); +} + } // namespace sets } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/sets/theory_sets_rewriter.h b/src/theory/sets/theory_sets_rewriter.h index dc4f64762..2d321f11c 100644 --- a/src/theory/sets/theory_sets_rewriter.h +++ b/src/theory/sets/theory_sets_rewriter.h @@ -69,48 +69,55 @@ class TheorySetsRewriter : public TheoryRewriter // often this will suffice return postRewrite(equality).d_node; } -private: + + private: /** * Returns true if elementTerm is in setTerm, where both terms are constants. */ bool checkConstantMembership(TNode elementTerm, TNode setTerm); - /** - * rewrites for n include: - * - (set.map f (as set.empty (Set T1)) = (as set.empty (Set T2)) - * - (set.map f (set.singleton x)) = (set.singleton (apply f x)) - * - (set.map f (set.union A B)) = - * (set.union (set.map f A) (set.map f B)) - * where f: T1 -> T2 - */ - RewriteResponse postRewriteMap(TNode n); + /** + * rewrites for n include: + * - (set.map f (as set.empty (Set T1)) = (as set.empty (Set T2)) + * - (set.map f (set.singleton x)) = (set.singleton (apply f x)) + * - (set.map f (set.union A B)) = + * (set.union (set.map f A) (set.map f B)) + * where f: T1 -> T2 + */ + RewriteResponse postRewriteMap(TNode n); - /** - * rewrites for n include: - * - (set.filter p (as set.empty (Set T)) = (as set.empty (Set T)) - * - (set.filter p (set.singleton x)) = - * (ite (p x) (set.singleton x) (as set.empty (Set T))) - * - (set.filter p (set.union A B)) = - * (set.union (set.filter p A) (set.filter p B)) - * where p: T -> Bool - */ - RewriteResponse postRewriteFilter(TNode n); - /** - * rewrites for n include: - * - (set.fold f t (as set.empty (Set T))) = t - * - (set.fold f t (set.singleton x)) = (f t x) - * - (set.fold f t (set.union A B)) = (set.fold f (set.fold f t A) B)) - * where f: T -> S -> S, and t : S - */ - RewriteResponse postRewriteFold(TNode n); - /** - * rewrites for n include: - * - ((_ rel.group n1 ... nk) (as set.empty (Relation T))) = - * (rel.singleton (as set.empty (Relation T) )) - * - ((_ rel.group n1 ... nk) (set.singleton x)) = - * (set.singleton (set.singleton x)) - * - Evaluation of ((_ rel.group n1 ... nk) A) when A is a constant - */ - RewriteResponse postRewriteGroup(TNode n); + /** + * rewrites for n include: + * - (set.filter p (as set.empty (Set T)) = (as set.empty (Set T)) + * - (set.filter p (set.singleton x)) = + * (ite (p x) (set.singleton x) (as set.empty (Set T))) + * - (set.filter p (set.union A B)) = + * (set.union (set.filter p A) (set.filter p B)) + * where p: T -> Bool + */ + RewriteResponse postRewriteFilter(TNode n); + /** + * rewrites for n include: + * - (set.fold f t (as set.empty (Set T))) = t + * - (set.fold f t (set.singleton x)) = (f t x) + * - (set.fold f t (set.union A B)) = (set.fold f (set.fold f t A) B)) + * where f: T -> S -> S, and t : S + */ + RewriteResponse postRewriteFold(TNode n); + /** + * rewrites for n include: + * - ((_ rel.group n1 ... nk) (as set.empty (Relation T))) = + * (rel.singleton (as set.empty (Relation T) )) + * - ((_ rel.group n1 ... nk) (set.singleton x)) = + * (set.singleton (set.singleton x)) + * - Evaluation of ((_ rel.group n1 ... nk) A) when A is a constant + */ + RewriteResponse postRewriteGroup(TNode n); + /** + * @param n has the form ((_ rel.aggr n1 ... n_k) f initial A) + * where initial and A are constants + * @return the aggregation result. + */ + RewriteResponse postRewriteAggregate(TNode n); }; /* class TheorySetsRewriter */ } // namespace sets diff --git a/src/theory/sets/theory_sets_type_rules.cpp b/src/theory/sets/theory_sets_type_rules.cpp index 8a163489a..f9dd7d390 100644 --- a/src/theory/sets/theory_sets_type_rules.cpp +++ b/src/theory/sets/theory_sets_type_rules.cpp @@ -27,6 +27,8 @@ namespace cvc5::internal { namespace theory { namespace sets { +using namespace cvc5::internal::theory::datatypes; + TypeNode SetsBinaryOperatorTypeRule::computeType(NodeManager* nodeManager, TNode n, bool check) @@ -612,6 +614,72 @@ TypeNode RelationGroupTypeRule::computeType(NodeManager* nm, TNode n, bool check return nm->mkSetType(setType); } +TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::RELATION_AGGREGATE && n.hasOperator() + && n.getOperator().getKind() == kind::RELATION_AGGREGATE_OP); + ProjectOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + + TypeNode functionType = n[0].getType(check); + TypeNode initialValueType = n[1].getType(check); + TypeNode setType = n[2].getType(check); + + if (check) + { + if (!setType.isSet()) + { + std::stringstream ss; + ss << "RELATION_PROJECT operator expects a table. Found '" << n[2] + << "' of type '" << setType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode tupleType = setType.getSetElementType(); + if (!tupleType.isTuple()) + { + std::stringstream ss; + ss << "TABLE_PROJECT operator expects a table. Found '" << n[2] + << "' of type '" << setType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TupleUtils::checkTypeIndices(n, tupleType, indices); + + TypeNode elementType = setType.getSetElementType(); + + if (!(functionType.isFunction())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " T T) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes = functionType.getArgTypes(); + TypeNode rangeType = functionType.getRangeType(); + if (!(argTypes.size() == 2 && argTypes[0] == elementType + && argTypes[1] == rangeType)) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " T T). " + << "Found a function of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + if (rangeType != initialValueType) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects an initial value of type " + << rangeType << ". Found a term of type '" << initialValueType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + return nm->mkSetType(functionType.getRangeType()); +} + Cardinality SetsProperties::computeCardinality(TypeNode type) { Assert(type.getKind() == kind::SET_TYPE); diff --git a/src/theory/sets/theory_sets_type_rules.h b/src/theory/sets/theory_sets_type_rules.h index 551513fe6..ed973669e 100644 --- a/src/theory/sets/theory_sets_type_rules.h +++ b/src/theory/sets/theory_sets_type_rules.h @@ -223,6 +223,19 @@ struct RelationGroupTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct RelationGroupTypeRule */ +/** + * Relation aggregate operator is indexed by a list of indices (n_1, ..., n_k). + * It ensures that it has 3 arguments: + * - A combining function of type (-> (Tuple T_1 ... T_j) T T) + * - Initial value of type T + * - A relation of type (Relation T_1 ... T_j) where 0 <= n_1, ..., n_k < j + * the returned type is (Relation T). + */ +struct RelationAggregateTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct RelationAggregateTypeRule */ + struct SetsProperties { static Cardinality computeCardinality(TypeNode type); diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index 2ecf50241..75f6a2929 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -2505,6 +2505,7 @@ set(regress_1_tests regress1/sets/proj-issue164.smt2 regress1/sets/proj-issue178.smt2 regress1/sets/proj-issue494-finite-leafof.smt2 + regress1/sets/relation_aggregate1.smt2 regress1/sets/relation_group1.smt2 regress1/sets/relation_group2.smt2 regress1/sets/relation_group3.smt2 diff --git a/test/regress/cli/regress1/sets/relation_aggregate1.smt2 b/test/regress/cli/regress1/sets/relation_aggregate1.smt2 new file mode 100644 index 000000000..c4677901e --- /dev/null +++ b/test/regress/cli/regress1/sets/relation_aggregate1.smt2 @@ -0,0 +1,31 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(set-option :fmf-bound true) +(set-option :uf-lazy-ll true) + +(define-fun sumByCategory ((x (Tuple String String Int)) (y (Tuple String Int))) (Tuple String Int) + (tuple + ((_ tuple.select 0) x) + (+ ((_ tuple.select 2) x) ((_ tuple.select 1) y)))) + +(declare-fun categorySales () (Set (Tuple String Int))) + +;(define-fun categorySales () (Set (Tuple String Int)) +; (set.union +; (set.singleton (tuple "Software" 5)) +; (set.singleton (tuple "Hardware" 4)))) + +(assert + (= categorySales + ((_ rel.aggr 0) + sumByCategory + (tuple "" 0) + (set.union + (set.singleton (tuple "Software" "win" 1)) + (set.singleton (tuple "Software" "mac" 4)) + (set.singleton (tuple "Hardware" "cpu" 2)) + (set.singleton (tuple "Hardware" "gpu" 2)))))) + +(check-sat)