From 17d5e26a9a0aac458cd8c9a0bf8d99b62efadc52 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Wed, 6 Jul 2022 16:02:23 -0500 Subject: [PATCH] Add rel.project operator to sets (#8929) --- src/api/cpp/cvc5.cpp | 7 ++ src/api/cpp/cvc5_kind.h | 20 ++++++ src/parser/smt2/Smt2.g | 14 ++++ src/parser/smt2/smt2.cpp | 7 +- src/printer/smt2/smt2_printer.cpp | 16 +++++ src/theory/bags/bag_reduction.cpp | 15 +++++ src/theory/bags/bag_reduction.h | 6 ++ src/theory/bags/bags_utils.cpp | 26 +------- src/theory/bags/theory_bags.cpp | 7 +- src/theory/bags/theory_bags_type_rules.h | 6 +- src/theory/sets/kinds | 16 ++++- src/theory/sets/set_reduction.cpp | 15 +++++ src/theory/sets/set_reduction.h | 5 ++ src/theory/sets/theory_sets.cpp | 7 +- src/theory/sets/theory_sets_rewriter.cpp | 13 ++++ src/theory/sets/theory_sets_rewriter.h | 5 ++ src/theory/sets/theory_sets_type_rules.cpp | 64 ++++++++++++++++++- src/theory/sets/theory_sets_type_rules.h | 11 ++++ test/regress/cli/CMakeLists.txt | 3 + .../cli/regress1/bags/table_project2.smt2 | 11 ++++ .../cli/regress1/sets/relation_project1.smt2 | 26 ++++++++ .../cli/regress1/sets/relation_project2.smt2 | 10 +++ test/unit/api/cpp/op_black.cpp | 3 + test/unit/api/java/OpTest.java | 3 + test/unit/api/python/test_op.py | 3 + 25 files changed, 283 insertions(+), 36 deletions(-) create mode 100644 test/regress/cli/regress1/bags/table_project2.smt2 create mode 100644 test/regress/cli/regress1/sets/relation_project1.smt2 create mode 100644 test/regress/cli/regress1/sets/relation_project2.smt2 diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 692eff021..b72f85c44 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -311,6 +311,7 @@ const static std::unordered_map> KIND_ENUM(RELATION_IDEN, internal::Kind::RELATION_IDEN), KIND_ENUM(RELATION_GROUP, internal::Kind::RELATION_GROUP), KIND_ENUM(RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE), + KIND_ENUM(RELATION_PROJECT, internal::Kind::RELATION_PROJECT), /* Bags ------------------------------------------------------------- */ KIND_ENUM(BAG_UNION_MAX, internal::Kind::BAG_UNION_MAX), KIND_ENUM(BAG_UNION_DISJOINT, internal::Kind::BAG_UNION_DISJOINT), @@ -638,6 +639,8 @@ const static std::unordered_map s_op_kinds{ {TUPLE_PROJECT, internal::Kind::TUPLE_PROJECT_OP}, {RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE_OP}, {RELATION_GROUP, internal::Kind::RELATION_GROUP_OP}, + {RELATION_PROJECT, internal::Kind::RELATION_PROJECT_OP}, {TABLE_PROJECT, internal::Kind::TABLE_PROJECT_OP}, {TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE_OP}, {TABLE_JOIN, internal::Kind::TABLE_JOIN_OP}, @@ -1959,6 +1963,7 @@ size_t Op::getNumIndicesHelper() const case TUPLE_PROJECT: case RELATION_AGGREGATE: case RELATION_GROUP: + case RELATION_PROJECT: case TABLE_AGGREGATE: case TABLE_GROUP: case TABLE_JOIN: @@ -2121,6 +2126,7 @@ Term Op::getIndexHelper(size_t index) const case TUPLE_PROJECT: case RELATION_AGGREGATE: case RELATION_GROUP: + case RELATION_PROJECT: case TABLE_AGGREGATE: case TABLE_GROUP: case TABLE_JOIN: @@ -6164,6 +6170,7 @@ Op Solver::mkOp(Kind kind, const std::vector& args) const case TUPLE_PROJECT: case RELATION_AGGREGATE: case RELATION_GROUP: + case RELATION_PROJECT: case TABLE_AGGREGATE: case TABLE_GROUP: case TABLE_JOIN: diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index 4d4234f03..0220d48af 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -3398,6 +3398,26 @@ enum Kind : int32_t * \endrst */ RELATION_AGGREGATE, + /** + * Relation projection operator extends tuple projection operator to sets. + * + * - Arity: ``1`` + * - ``1:`` Term of relation Sort + * + * - Indices: ``n`` + * - ``1..n:`` Indices of the projection + * + * - Create Term of this Kind with: + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * - Create Op of this kind with: + * - Solver::mkOp(Kind, const std::vector&) const + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst + */ + RELATION_PROJECT, /* Bags ------------------------------------------------------------------ */ diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 96421e6ce..ebeabe3f6 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1441,6 +1441,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2] cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, indices); expr = SOLVER->mkTerm(op, {expr}); } + | LPAREN_TOK RELATION_PROJECT_TOK term[expr,expr2] RPAREN_TOK + { + std::vector indices; + cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_PROJECT, indices); + expr = SOLVER->mkTerm(op, {expr}); + } | /* an atomic term (a term with no subterms) */ termAtomic[atomTerm] { expr = atomTerm; } ; @@ -1624,6 +1630,13 @@ identifier[cvc5::ParseOp& p] p.d_kind = cvc5::RELATION_AGGREGATE; p.d_op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, numerals); } + | RELATION_PROJECT_TOK nonemptyNumeralList[numerals] + { + // we adopt a special syntax (_ rel.project i_1 ... i_n) where + // i_1, ..., i_n are numerals + p.d_kind = cvc5::RELATION_PROJECT; + p.d_op = SOLVER->mkOp(cvc5::RELATION_PROJECT, numerals); + } | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals] { cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName); @@ -2238,6 +2251,7 @@ TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) } TABLE_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.group'; RELATION_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.group'; RELATION_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.aggr'; +RELATION_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.project'; FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card'; HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->'; diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 5f4ae42e4..fb08a0565 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -258,7 +258,9 @@ void Smt2::addSepOperators() { void Smt2::addCoreSymbols() { defineType("Bool", d_solver->getBooleanSort(), true); - defineType("Table", d_solver->mkBagSort(d_solver->mkTupleSort({})), true); + Sort tupleSort = d_solver->mkTupleSort({}); + defineType("Relation", d_solver->mkSetSort(tupleSort), true); + defineType("Table", d_solver->mkBagSort(tupleSort), true); defineVar("true", d_solver->mkTrue(), true); defineVar("false", d_solver->mkFalse(), true); addOperator(cvc5::AND, "and"); @@ -1131,7 +1133,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN || p.d_kind == cvc5::TABLE_GROUP || p.d_kind == cvc5::RELATION_GROUP - || p.d_kind == cvc5::RELATION_AGGREGATE) + || p.d_kind == cvc5::RELATION_AGGREGATE + || p.d_kind == cvc5::RELATION_PROJECT) { cvc5::Term ret = d_solver->mkTerm(p.d_op, args); Trace("parser") << "applyParseOp: return projection " << ret << std::endl; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 3ffe7641c..ccb93ff3c 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -851,6 +851,21 @@ void Smt2Printer::toStream(std::ostream& out, } return; } + case kind::RELATION_PROJECT: + { + ProjectOp op = n.getOperator().getConst(); + if (op.getIndices().empty()) + { + // e.g. (rel.project A) + out << "rel.project " << n[0] << ")"; + } + else + { + // e.g. ((_ rel.project 2 4 4) A) + out << "(_ rel.project" << op << ") " << n[0] << ")"; + } + return; + } case kind::CONSTRUCTOR_TYPE: { out << n[n.getNumChildren()-1]; @@ -1192,6 +1207,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::RELATION_JOIN_IMAGE: return "rel.join_image"; case kind::RELATION_GROUP: return "rel.group"; case kind::RELATION_AGGREGATE: return "rel.aggr"; + case kind::RELATION_PROJECT: return "rel.project"; // bag theory case kind::BAG_TYPE: return "Bag"; diff --git a/src/theory/bags/bag_reduction.cpp b/src/theory/bags/bag_reduction.cpp index b2713f882..8099c1342 100644 --- a/src/theory/bags/bag_reduction.cpp +++ b/src/theory/bags/bag_reduction.cpp @@ -226,6 +226,21 @@ Node BagReduction::reduceAggregateOperator(Node node) return map; } +Node BagReduction::reduceProjectOperator(Node n) +{ + Assert(n.getKind() == TABLE_PROJECT); + NodeManager* nm = NodeManager::currentNM(); + Node A = n[0]; + TypeNode elementType = A.getType().getBagElementType(); + ProjectOp projectOp = n.getOperator().getConst(); + Node op = nm->mkConst(TUPLE_PROJECT_OP, projectOp); + Node t = nm->mkBoundVar("t", elementType); + Node projection = nm->mkNode(TUPLE_PROJECT, op, t); + Node lambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, t), projection); + Node setMap = nm->mkNode(BAG_MAP, lambda, A); + return setMap; +} + } // namespace bags } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/bags/bag_reduction.h b/src/theory/bags/bag_reduction.h index cf391a120..6c7122afe 100644 --- a/src/theory/bags/bag_reduction.h +++ b/src/theory/bags/bag_reduction.h @@ -105,6 +105,12 @@ class BagReduction * ((_ table.group n1 ... nk) A)) */ static Node reduceAggregateOperator(Node node); + /** + * @param n has the form ((table.project n1 ... nk) A) where A has type + * (Bag T) + * @return (bag.map (lambda ((t T)) ((_ tuple.project n1 ... nk) t)) A) + */ + static Node reduceProjectOperator(Node n); }; } // namespace bags diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp index 63112bee6..183f96a80 100644 --- a/src/theory/bags/bags_utils.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -1056,30 +1056,8 @@ Node BagsUtils::evaluateGroup(TNode n) Node BagsUtils::evaluateTableProject(TNode n) { Assert(n.getKind() == TABLE_PROJECT); - // Examples - // -------- - // - ((_ table.project 1) (bag (tuple true "a") 4)) = (bag (tuple "a") 4) - // - (table.project (bag.union_disjoint - // (bag (tuple "a") 4) - // (bag (tuple "b") 3))) = (bag tuple 7) - - Node A = n[0]; - - std::map elementsA = BagsUtils::getBagElements(A); - - std::map elements; - std::vector indices = - n.getOperator().getConst().getIndices(); - - for (const auto& [a, countA] : elementsA) - { - Node element = TupleUtils::getTupleProjection(indices, a); - // multiple elements could be projected to the same tuple. - // Zero is the default value for Rational values. - elements[element] += countA; - } - - Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements); + Node bagMap = BagReduction::reduceProjectOperator(n); + Node ret = evaluateBagMap(bagMap); return ret; } diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 1581a091b..e150199ee 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -112,6 +112,12 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector& lems) Trace("bags::ppr") << "reduce(" << atom << ") = " << ret << std::endl; return TrustNode::mkTrustRewrite(atom, ret, nullptr); } + case kind::TABLE_PROJECT: + { + Node ret = BagReduction::reduceProjectOperator(atom); + Trace("bags::ppr") << "reduce(" << atom << ") = " << ret << std::endl; + return TrustNode::mkTrustRewrite(atom, ret, nullptr); + } default: return TrustNode::null(); } } @@ -465,7 +471,6 @@ void TheoryBags::preRegisterTerm(TNode n) case BAG_TO_SET: case BAG_IS_SINGLETON: case BAG_PARTITION: - case TABLE_PROJECT: { std::stringstream ss; ss << "Term of kind " << n.getKind() << " is not supported yet"; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 86fa42822..a498c9443 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -181,13 +181,13 @@ struct TableProductTypeRule /** * Table project is indexed by a list of indices (n_1, ..., n_m). It ensures * that the argument is a bag of tuples whose arity k is greater than each n_i - * for i = 1, ..., m. If the argument is of type (Bag (Tuple T_1 ... T_k)), then - * the returned type is (Bag (Tuple T_{n_1} ... T_{n_m})). + * for i = 1, ..., m. If the argument is of type (Table T_1 ... T_k), then + * the returned type is (Table T_{n_1} ... T_{n_m}). */ struct TableProjectTypeRule { static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); -}; /* struct BagFoldTypeRule */ +}; /* struct TableProjectTypeRule */ /** * Table aggregate operator is indexed by a list of indices (n_1, ..., n_k). diff --git a/src/theory/sets/kinds b/src/theory/sets/kinds index 4c5899e96..07db0da7e 100644 --- a/src/theory/sets/kinds +++ b/src/theory/sets/kinds @@ -109,6 +109,16 @@ constant RELATION_AGGREGATE_OP \ parameterized RELATION_AGGREGATE RELATION_AGGREGATE_OP 3 "relation aggregate" +# rel.project operator extends datatypes tuple_project operator to a set of tuples +constant RELATION_PROJECT_OP \ + class \ + ProjectOp+ \ + ::cvc5::internal::ProjectOpHashFunction \ + "theory/datatypes/project_op.h" \ + "operator for RELATION_PROJECT; payload is an instance of the cvc5::internal::ProjectOp class" + +parameterized RELATION_PROJECT RELATION_PROJECT_OP 1 "relation projection" + operator RELATION_JOIN 2 "relation join" operator RELATION_PRODUCT 2 "relation cartesian product" operator RELATION_TRANSPOSE 1 "relation transpose" @@ -136,14 +146,16 @@ typerule SET_FOLD ::cvc5::internal::theory::sets::SetFoldTypeRule typerule RELATION_JOIN ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule typerule RELATION_PRODUCT ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule -typerule RELATION_TRANSPOSE ::cvc5::internal::theory::sets::RelTransposeTypeRule +typerule RELATION_TRANSPOSE ::cvc5::internal::theory::sets::RelTransposeTypeRule typerule RELATION_TCLOSURE ::cvc5::internal::theory::sets::RelTransClosureTypeRule -typerule RELATION_JOIN_IMAGE ::cvc5::internal::theory::sets::JoinImageTypeRule +typerule RELATION_JOIN_IMAGE ::cvc5::internal::theory::sets::JoinImageTypeRule typerule RELATION_IDEN ::cvc5::internal::theory::sets::RelIdenTypeRule typerule RELATION_GROUP_OP "SimpleTypeRule" typerule RELATION_GROUP ::cvc5::internal::theory::sets::RelationGroupTypeRule typerule RELATION_AGGREGATE_OP "SimpleTypeRule" typerule RELATION_AGGREGATE ::cvc5::internal::theory::sets::RelationAggregateTypeRule +typerule RELATION_PROJECT_OP "SimpleTypeRule" +typerule RELATION_PROJECT ::cvc5::internal::theory::sets::RelationProjectTypeRule construle SET_UNION ::cvc5::internal::theory::sets::SetsBinaryOperatorTypeRule construle SET_SINGLETON ::cvc5::internal::theory::sets::SingletonTypeRule diff --git a/src/theory/sets/set_reduction.cpp b/src/theory/sets/set_reduction.cpp index f8c4ed4ae..988374d6a 100644 --- a/src/theory/sets/set_reduction.cpp +++ b/src/theory/sets/set_reduction.cpp @@ -144,6 +144,21 @@ Node SetReduction::reduceAggregateOperator(Node node) return map; } +Node SetReduction::reduceProjectOperator(Node n) +{ + Assert(n.getKind() == RELATION_PROJECT); + NodeManager* nm = NodeManager::currentNM(); + Node A = n[0]; + TypeNode elementType = A.getType().getSetElementType(); + ProjectOp projectOp = n.getOperator().getConst(); + Node op = nm->mkConst(TUPLE_PROJECT_OP, projectOp); + Node t = nm->mkBoundVar("t", elementType); + Node projection = nm->mkNode(TUPLE_PROJECT, op, t); + Node lambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, t), projection); + Node setMap = nm->mkNode(SET_MAP, lambda, A); + return setMap; +} + } // namespace sets } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/sets/set_reduction.h b/src/theory/sets/set_reduction.h index 3b0c4526e..38a5740c1 100644 --- a/src/theory/sets/set_reduction.h +++ b/src/theory/sets/set_reduction.h @@ -74,6 +74,11 @@ class SetReduction * ((_ rel.group n1 ... nk) A)) */ static Node reduceAggregateOperator(Node node); + /** + * @param n has the form ((rel.project n1 ... nk) A) where A has type (Set T) + * @return (set.map (lambda ((t T)) ((_ tuple.project n1 ... nk) t)) A) + */ + static Node reduceProjectOperator(Node n); }; } // namespace sets diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index b09080186..b1ae7a8eb 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -168,11 +168,16 @@ TrustNode TheorySets::ppRewrite(TNode n, std::vector& lems) d_im.lemma(andNode, InferenceId::BAGS_FOLD); return TrustNode::mkTrustRewrite(n, ret, nullptr); } - if (nk == TABLE_AGGREGATE) + if (nk == RELATION_AGGREGATE) { Node ret = SetReduction::reduceAggregateOperator(n); return TrustNode::mkTrustRewrite(ret, ret, nullptr); } + if (nk == RELATION_PROJECT) + { + Node ret = SetReduction::reduceProjectOperator(n); + return TrustNode::mkTrustRewrite(ret, ret, nullptr); + } return d_internal->ppRewrite(n, lems); } diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index e9b3b31b0..cbb65b876 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -22,6 +22,7 @@ #include "theory/datatypes/tuple_utils.h" #include "theory/sets/normal_form.h" #include "theory/sets/rels_utils.h" +#include "theory/sets/set_reduction.h" #include "util/rational.h" using namespace cvc5::internal::kind; @@ -591,6 +592,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { case RELATION_GROUP: return postRewriteGroup(node); case RELATION_AGGREGATE: return postRewriteAggregate(node); + case RELATION_PROJECT: return postRewriteProject(node); default: break; } @@ -781,6 +783,17 @@ RewriteResponse TheorySetsRewriter::postRewriteAggregate(TNode n) return RewriteResponse(REWRITE_DONE, n); } +RewriteResponse TheorySetsRewriter::postRewriteProject(TNode n) +{ + Assert(n.getKind() == RELATION_PROJECT); + Node ret = SetReduction::reduceProjectOperator(n); + if (ret != n) + { + return RewriteResponse(REWRITE_AGAIN_FULL, ret); + } + return RewriteResponse(REWRITE_DONE, n); +} + } // namespace sets } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/sets/theory_sets_rewriter.h b/src/theory/sets/theory_sets_rewriter.h index 2d321f11c..b98e3a531 100644 --- a/src/theory/sets/theory_sets_rewriter.h +++ b/src/theory/sets/theory_sets_rewriter.h @@ -118,6 +118,11 @@ class TheorySetsRewriter : public TheoryRewriter * @return the aggregation result. */ RewriteResponse postRewriteAggregate(TNode n); + /** + * If A has type (Set T), then rewrite ((rel.project n1 ... nk) A) as + * (set.map (lambda ((t T)) ((_ tuple.project n1 ... nk) t)) A) + */ + RewriteResponse postRewriteProject(TNode n); }; /* class TheorySetsRewriter */ } // namespace sets diff --git a/src/theory/sets/theory_sets_type_rules.cpp b/src/theory/sets/theory_sets_type_rules.cpp index f9dd7d390..83f56b90a 100644 --- a/src/theory/sets/theory_sets_type_rules.cpp +++ b/src/theory/sets/theory_sets_type_rules.cpp @@ -15,9 +15,10 @@ #include "theory/sets/theory_sets_type_rules.h" -#include #include +#include "expr/dtype.h" +#include "expr/dtype_cons.h" #include "theory/sets/normal_form.h" #include "util/cardinality.h" #include "theory/datatypes/project_op.h" @@ -632,7 +633,7 @@ TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm, if (!setType.isSet()) { std::stringstream ss; - ss << "RELATION_PROJECT operator expects a table. Found '" << n[2] + ss << "RELATION_AGGREGATE operator expects a set. Found '" << n[2] << "' of type '" << setType << "'."; throw TypeCheckingExceptionPrivate(n, ss.str()); } @@ -641,7 +642,7 @@ TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm, if (!tupleType.isTuple()) { std::stringstream ss; - ss << "TABLE_PROJECT operator expects a table. Found '" << n[2] + ss << "RELATION_AGGREGATE operator expects a relation. Found '" << n[2] << "' of type '" << setType << "'."; throw TypeCheckingExceptionPrivate(n, ss.str()); } @@ -680,6 +681,63 @@ TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm, return nm->mkSetType(functionType.getRangeType()); } +TypeNode RelationProjectTypeRule::computeType(NodeManager* nm, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::RELATION_PROJECT && n.hasOperator() + && n.getOperator().getKind() == kind::RELATION_PROJECT_OP); + ProjectOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + TypeNode setType = n[0].getType(check); + if (check) + { + if (n.getNumChildren() != 1) + { + std::stringstream ss; + ss << "operands in term " << n << " are " << n.getNumChildren() + << ", but RELATION_PROJECT expects 1 operand."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + if (!setType.isSet()) + { + std::stringstream ss; + ss << "RELATION_PROJECT operator expects a set. Found '" << n[0] + << "' of type '" << setType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode tupleType = setType.getSetElementType(); + if (!tupleType.isTuple()) + { + std::stringstream ss; + ss << "RELATION_PROJECT operator expects a relation. Found '" << n[0] + << "' of type '" << setType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + // make sure all indices are less than the length of the tuple type + DType dType = tupleType.getDType(); + DTypeConstructor constructor = dType[0]; + size_t numArgs = constructor.getNumArgs(); + for (uint32_t index : indices) + { + std::stringstream ss; + if (index >= numArgs) + { + ss << "Index " << index << " in term " << n << " is >= " << numArgs + << " which is the number of columns in " << n[0] << "."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + } + TypeNode tupleType = setType.getSetElementType(); + TypeNode retTupleType = + TupleUtils::getTupleProjectionType(indices, tupleType); + return nm->mkSetType(retTupleType); +} + Cardinality SetsProperties::computeCardinality(TypeNode type) { Assert(type.getKind() == kind::SET_TYPE); diff --git a/src/theory/sets/theory_sets_type_rules.h b/src/theory/sets/theory_sets_type_rules.h index ed973669e..9755961c8 100644 --- a/src/theory/sets/theory_sets_type_rules.h +++ b/src/theory/sets/theory_sets_type_rules.h @@ -223,6 +223,17 @@ struct RelationGroupTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct RelationGroupTypeRule */ +/** + * Relation project is indexed by a list of indices (n_1, ..., n_m). It ensures + * that the argument is a set of tuples whose arity k is greater than each n_i + * for i = 1, ..., m. If the argument is of type (Relation T_1 ... T_k), then + * the returned type is (Relation T_{n_1} ... T_{n_m}). + */ +struct RelationProjectTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct RelationProjectTypeRule */ + /** * Relation aggregate operator is indexed by a list of indices (n_1, ..., n_k). * It ensures that it has 3 arguments: diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index bfb23057f..c8efaad52 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -1849,6 +1849,7 @@ set(regress_1_tests regress1/bags/table_join2.smt2 regress1/bags/table_join3.smt2 regress1/bags/table_project1.smt2 + regress1/bags/table_project2.smt2 regress1/bags/union_disjoint.smt2 regress1/bags/union_max1.smt2 regress1/bags/union_max2.smt2 @@ -2511,6 +2512,8 @@ set(regress_1_tests regress1/sets/relation_group3.smt2 regress1/sets/relation_group4.smt2 regress1/sets/relation_group5.smt2 + regress1/sets/relation_project1.smt2 + regress1/sets/relation_project2.smt2 regress1/sets/remove_check_free_31_6.smt2 regress1/sets/sets-disequal.smt2 regress1/sets/sets-tuple-poly.cvc.smt2 diff --git a/test/regress/cli/regress1/bags/table_project2.smt2 b/test/regress/cli/regress1/bags/table_project2.smt2 new file mode 100644 index 000000000..1af4e1e4a --- /dev/null +++ b/test/regress/cli/regress1/bags/table_project2.smt2 @@ -0,0 +1,11 @@ +(set-logic HO_ALL) +(set-info :status sat) +(set-option :fmf-bound true) +(set-option :uf-lazy-ll true) + +(declare-fun A () (Table String String)) +(declare-fun B () (Table String String)) + +(assert (= B ((_ table.project 1 0) A))) +(assert (bag.member (tuple "y" "x") B)) +(check-sat) diff --git a/test/regress/cli/regress1/sets/relation_project1.smt2 b/test/regress/cli/regress1/sets/relation_project1.smt2 new file mode 100644 index 000000000..7f1dea385 --- /dev/null +++ b/test/regress/cli/regress1/sets/relation_project1.smt2 @@ -0,0 +1,26 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(declare-fun A () (Relation String Int String Bool)) +(declare-fun B () (Relation Int Bool String String)) +(declare-fun C () (Relation String String)) +(declare-fun D () Relation) + +(assert + (= A + (set.union + (set.singleton (tuple "x" 0 "y" false)) + (set.singleton (tuple "x" 1 "z" true))))) + + +; (set.union (set.singleton (tuple 0 false "x" "y")) (set.singleton (tuple 1 true "x" "z"))) +(assert (= B ((_ rel.project 1 3 0 2) A))) + +; (set.singleton (tuple "x" "x")) +(assert (= C ((_ rel.project 0 0) A))) + +; (set.singleton tuple) +(assert (= D (rel.project A))) + +(check-sat) diff --git a/test/regress/cli/regress1/sets/relation_project2.smt2 b/test/regress/cli/regress1/sets/relation_project2.smt2 new file mode 100644 index 000000000..ca39f2231 --- /dev/null +++ b/test/regress/cli/regress1/sets/relation_project2.smt2 @@ -0,0 +1,10 @@ +(set-logic HO_ALL) +(set-info :status sat) +(set-option :uf-lazy-ll true) + +(declare-fun A () (Relation String String)) +(declare-fun B () (Relation String String)) + +(assert (= B ((_ rel.project 1 0) A))) +(assert (set.member (tuple "y" "x") B)) +(check-sat) diff --git a/test/unit/api/cpp/op_black.cpp b/test/unit/api/cpp/op_black.cpp index 498c18c00..72e05acd9 100644 --- a/test/unit/api/cpp/op_black.cpp +++ b/test/unit/api/cpp/op_black.cpp @@ -103,6 +103,9 @@ TEST_F(TestApiBlackOp, getNumIndices) Op tupleProject = d_solver.mkOp(TUPLE_PROJECT, indices); ASSERT_EQ(indices.size(), tupleProject.getNumIndices()); + Op relationProject = d_solver.mkOp(RELATION_PROJECT, indices); + ASSERT_EQ(indices.size(), relationProject.getNumIndices()); + Op tableProject = d_solver.mkOp(TABLE_PROJECT, indices); ASSERT_EQ(indices.size(), tableProject.getNumIndices()); } diff --git a/test/unit/api/java/OpTest.java b/test/unit/api/java/OpTest.java index c14518a2f..688c0aaea 100644 --- a/test/unit/api/java/OpTest.java +++ b/test/unit/api/java/OpTest.java @@ -120,6 +120,9 @@ class OpTest Op tupleProject = d_solver.mkOp(TUPLE_PROJECT, indices); assertEquals(6, tupleProject.getNumIndices()); + Op relationProject = d_solver.mkOp(RELATION_PROJECT, indices); + assertEquals(6, relationProject.getNumIndices()); + Op tableProject = d_solver.mkOp(TABLE_PROJECT, indices); assertEquals(6, tableProject.getNumIndices()); } diff --git a/test/unit/api/python/test_op.py b/test/unit/api/python/test_op.py index 4ba607926..5959a53c9 100644 --- a/test/unit/api/python/test_op.py +++ b/test/unit/api/python/test_op.py @@ -105,6 +105,9 @@ def test_get_num_indices(solver): tuple_project_op = solver.mkOp(Kind.TUPLE_PROJECT, *indices) assert len(indices) == tuple_project_op.getNumIndices() + relation_project_op = solver.mkOp(Kind.RELATION_PROJECT, *indices) + assert len(indices) == relation_project_op.getNumIndices() + table_project_op = solver.mkOp(Kind.TABLE_PROJECT, *indices) assert len(indices) == table_project_op.getNumIndices() -- 2.30.2