From: mudathirmahgoub Date: Thu, 5 May 2022 13:41:30 +0000 (-0500) Subject: Add operators table.aggr and table.join (#8681) X-Git-Tag: cvc5-1.0.1~167 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=ae61d9b6fbd7966d70a00e543b3a8724ed205a41;p=cvc5.git Add operators table.aggr and table.join (#8681) --- diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 115ecddc9..0c6167963 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -331,6 +331,8 @@ const static std::unordered_map> KIND_ENUM(BAG_PARTITION, internal::Kind::BAG_PARTITION), KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT), KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT), + KIND_ENUM(TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE), + KIND_ENUM(TABLE_JOIN, internal::Kind::TABLE_JOIN), /* Strings ---------------------------------------------------------- */ KIND_ENUM(STRING_CONCAT, internal::Kind::STRING_CONCAT), KIND_ENUM(STRING_IN_REGEXP, internal::Kind::STRING_IN_REGEXP), @@ -649,6 +651,10 @@ const static std::unordered_map& args) const case TABLE_PROJECT: res = mkOpHelper(kind, internal::TableProjectOp(args)); break; + case TABLE_AGGREGATE: + res = mkOpHelper(kind, internal::TableAggregateOp(args)); + break; + case TABLE_JOIN: + res = mkOpHelper(kind, internal::TableJoinOp(args)); + break; default: if (nargs == 0) { diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index 8ee2f378c..a7fa5644c 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -3726,10 +3726,10 @@ enum Kind : int32_t * Table projection operator extends tuple projection operator to tables. * * - Arity: ``1`` - * - ``1:`` Term of tuple Sort + * - ``1:`` Term of table Sort * * - Indices: ``n`` - * - ``1..n:`` The table indices to project + * - ``1..n:`` Indices of the projection * * - Create Term of this Kind with: * - Solver::mkTerm(const Op&, const std::vector&) const @@ -3738,6 +3738,58 @@ enum Kind : int32_t * - Solver::mkOp(Kind, const std::vector&) const */ TABLE_PROJECT, + /** + * Table aggregate operator has the form + * :math:`((\_ \; table.aggr \; n_1 ... n_k) f i A)` + * where :math:`n_1, ..., n_k` are natural numbers, + * :math:`f` is a function of type + * :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)`, + * :math:`i` has the type :math:`T`, + * and :math`A` has type :math:`Table T_1 ... T_j`. + * The returned type is :math:`(Bag \; T)`. + * + * This operator aggregates elements in A that have the same tuple projection + * with indices n_1, ..., n_k using the combining function :math:`f`, + * and initial value :math:`i`. + * + * - Arity: ``3`` + * - ``1:`` Term of sort :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)` + * - ``2:`` Term of Sort :math:`T` + * - ``3:`` Term of table sort :math:`Table T_1 ... T_j` + * + * - Indices: ``n`` + * - ``1..n:`` Indices of the projection + * + * - Create Term of this Kind with: + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * - Create Op of this kind with: + * - Solver::mkOp(Kind, const std::vector&) const + */ + TABLE_AGGREGATE, + /** + * Table join operator has the form + * :math:`((\_ \; table.join \; m_1 \; n_1 \; \dots \; m_k \; n_k) \; A \; B)` + * where * :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers, + * and A, B are tables. + * This operator filters the product of two bags based on the equality of + * projected tuples using indices :math:`m_1, \dots, m_k` in table A, + * and indices :math:`n_1, \dots, n_k` in table B. + * + * - Arity: ``2`` + * - ``1:`` Term of table Sort + * - ``2:`` Term of table Sort + * + * - Indices: ``n`` + * - ``1..n:`` Indices of the projection + * + * - Create Term of this Kind with: + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * - Create Op of this kind with: + * - Solver::mkOp(Kind, const std::vector&) const + */ + TABLE_JOIN, /* Strings --------------------------------------------------------------- */ /** diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 2a5dcf162..239d36468 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1413,6 +1413,18 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2] cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_PROJECT, indices); expr = SOLVER->mkTerm(op, {expr}); } + | LPAREN_TOK TABLE_AGGREGATE_TOK term[expr,expr2] RPAREN_TOK + { + std::vector indices; + cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_AGGREGATE, indices); + expr = SOLVER->mkTerm(op, {expr}); + } + | LPAREN_TOK TABLE_JOIN_TOK term[expr,expr2] RPAREN_TOK + { + std::vector indices; + cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_JOIN, indices); + expr = SOLVER->mkTerm(op, {expr}); + } | /* an atomic term (a term with no subterms) */ termAtomic[atomTerm] { expr = atomTerm; } ; @@ -1561,6 +1573,20 @@ identifier[cvc5::ParseOp& p] p.d_kind = cvc5::TABLE_PROJECT; p.d_op = SOLVER->mkOp(cvc5::TABLE_PROJECT, numerals); } + | TABLE_AGGREGATE_TOK nonemptyNumeralList[numerals] + { + // we adopt a special syntax (_ table.aggr i_1 ... i_n) where + // i_1, ..., i_n are numerals + p.d_kind = cvc5::TABLE_AGGREGATE; + p.d_op = SOLVER->mkOp(cvc5::TABLE_AGGREGATE, numerals); + } + | TABLE_JOIN_TOK nonemptyNumeralList[numerals] + { + // we adopt a special syntax (_ table.join i_1 j_1 ... i_n j_n) where + // i_1, ..., i_n, j_1, ..., j_n are numerals + p.d_kind = cvc5::TABLE_JOIN; + p.d_op = SOLVER->mkOp(cvc5::TABLE_JOIN, numerals); + } | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals] { cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName); @@ -2170,6 +2196,8 @@ CHAR_TOK : { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_STRINGS) }? TUPLE_CONST_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple'; TUPLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple.project'; TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.project'; +TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.aggr'; +TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join'; FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card'; HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->'; diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index a4a16c214..dad36788e 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -1125,7 +1125,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) Trace("parser") << "applyParseOp: return selector " << ret << std::endl; return ret; } - else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT) + else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT + || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN) { cvc5::Term ret = d_solver->mkTerm(p.d_op, args); Trace("parser") << "applyParseOp: return projection " << ret << std::endl; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 19c67f672..3cda30e18 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -806,6 +806,37 @@ void Smt2Printer::toStream(std::ostream& out, } return; } + case kind::TABLE_AGGREGATE: + { + TableAggregateOp op = n.getOperator().getConst(); + if (op.getIndices().empty()) + { + // e.g. (table.project function initial_value bag) + out << "table.aggr " << n[0] << " " << n[1] << " " << n[2] << ")"; + } + else + { + // e.g. ((_ table.aggr 0) function initial_value bag) + out << "(_ table.aggr" << op << ") " << n[0] << " " << n[1] << " " << n[2] + << ")"; + } + return; + } + case kind::TABLE_JOIN: + { + TableJoinOp op = n.getOperator().getConst(); + if (op.getIndices().empty()) + { + // e.g. (table.join A B) + out << "table.join " << n[0] << " " << n[1] << ")"; + } + else + { + // e.g. ((_ table.project 0 1 2 3) A B) + out << "(_ table.join" << op << ") " << n[0] << " " << n[1] << ")"; + } + return; + } case kind::CONSTRUCTOR_TYPE: { out << n[n.getNumChildren()-1]; @@ -978,7 +1009,7 @@ void Smt2Printer::toStream(std::ostream& out, } } stringstream parens; - + for(size_t i = 0, c = 1; i < n.getNumChildren(); ) { if(toDepth != 0) { toStream(out, n[i], toDepth < 0 ? toDepth : toDepth - c, lbind); @@ -1074,9 +1105,9 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::GEQ: return ">="; case kind::DIVISION: case kind::DIVISION_TOTAL: return "/"; - case kind::INTS_DIVISION_TOTAL: + case kind::INTS_DIVISION_TOTAL: case kind::INTS_DIVISION: return "div"; - case kind::INTS_MODULUS_TOTAL: + case kind::INTS_MODULUS_TOTAL: case kind::INTS_MODULUS: return "mod"; case kind::ABS: return "abs"; case kind::IS_INTEGER: return "is_int"; @@ -1186,6 +1217,8 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_PARTITION: return "bag.partition"; case kind::TABLE_PRODUCT: return "table.product"; case kind::TABLE_PROJECT: return "table.project"; + case kind::TABLE_AGGREGATE: return "table.aggr"; + case kind::TABLE_JOIN: return "table.join"; // fp theory case kind::FLOATINGPOINT_FP: return "fp"; diff --git a/src/theory/bags/bag_reduction.cpp b/src/theory/bags/bag_reduction.cpp index 43e1e7ee9..5e6d71935 100644 --- a/src/theory/bags/bag_reduction.cpp +++ b/src/theory/bags/bag_reduction.cpp @@ -18,6 +18,8 @@ #include "expr/bound_var_manager.h" #include "expr/emptybag.h" #include "expr/skolem_manager.h" +#include "table_project_op.h" +#include "theory/datatypes/tuple_utils.h" #include "theory/quantifiers/fmf/bounded_integers.h" #include "util/rational.h" @@ -28,7 +30,7 @@ namespace cvc5::internal { namespace theory { namespace bags { -BagReduction::BagReduction(Env& env) : EnvObj(env) {} +BagReduction::BagReduction() {} BagReduction::~BagReduction() {} @@ -58,74 +60,68 @@ typedef expr::Attribute Node BagReduction::reduceFoldOperator(Node node, std::vector& asserts) { Assert(node.getKind() == BAG_FOLD); - if (d_env.getLogicInfo().isHigherOrder()) - { - NodeManager* nm = NodeManager::currentNM(); - SkolemManager* sm = nm->getSkolemManager(); - Node f = node[0]; - Node t = node[1]; - Node A = node[2]; - Node zero = nm->mkConstInt(Rational(0)); - Node one = nm->mkConstInt(Rational(1)); - // types - TypeNode bagType = A.getType(); - TypeNode elementType = A.getType().getBagElementType(); - TypeNode integerType = nm->integerType(); - TypeNode ufType = nm->mkFunctionType(integerType, elementType); - TypeNode resultType = t.getType(); - TypeNode combineType = nm->mkFunctionType(integerType, resultType); - TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType); - // skolem functions - Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A); - Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A); - Node unionDisjoint = sm->mkSkolemFunction( - SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A); - Node combine = sm->mkSkolemFunction( - SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A}); - - BoundVarManager* bvm = nm->getBoundVarManager(); - Node i = - bvm->mkBoundVar(node, "i", nm->integerType()); - Node iList = nm->mkNode(BOUND_VAR_LIST, i); - Node iMinusOne = nm->mkNode(SUB, i, one); - Node uf_i = nm->mkNode(APPLY_UF, uf, i); - Node combine_0 = nm->mkNode(APPLY_UF, combine, zero); - Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne); - Node combine_i = nm->mkNode(APPLY_UF, combine, i); - Node combine_n = nm->mkNode(APPLY_UF, combine, n); - Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero); - Node unionDisjoint_iMinusOne = - nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne); - Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i); - Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n); - Node combine_0_equal = combine_0.eqNode(t); - Node combine_i_equal = - combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne)); - Node unionDisjoint_0_equal = - unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType))); - Node singleton = nm->mkBag(elementType, uf_i, one); - - Node unionDisjoint_i_equal = unionDisjoint_i.eqNode( - nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne)); - Node interval_i = - nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n)); - - Node body_i = - nm->mkNode(IMPLIES, - interval_i, - nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal)); - Node forAll_i = - quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i); - Node nonNegative = nm->mkNode(GEQ, n, zero); - Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n); - asserts.push_back(forAll_i); - asserts.push_back(combine_0_equal); - asserts.push_back(unionDisjoint_0_equal); - asserts.push_back(unionDisjoint_n_equal); - asserts.push_back(nonNegative); - return combine_n; - } - return Node::null(); + NodeManager* nm = NodeManager::currentNM(); + SkolemManager* sm = nm->getSkolemManager(); + Node f = node[0]; + Node t = node[1]; + Node A = node[2]; + Node zero = nm->mkConstInt(Rational(0)); + Node one = nm->mkConstInt(Rational(1)); + // types + TypeNode bagType = A.getType(); + TypeNode elementType = A.getType().getBagElementType(); + TypeNode integerType = nm->integerType(); + TypeNode ufType = nm->mkFunctionType(integerType, elementType); + TypeNode resultType = t.getType(); + TypeNode combineType = nm->mkFunctionType(integerType, resultType); + TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType); + // skolem functions + Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A); + Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A); + Node unionDisjoint = sm->mkSkolemFunction( + SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A); + Node combine = sm->mkSkolemFunction( + SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A}); + + BoundVarManager* bvm = nm->getBoundVarManager(); + Node i = + bvm->mkBoundVar(node, "i", nm->integerType()); + Node iList = nm->mkNode(BOUND_VAR_LIST, i); + Node iMinusOne = nm->mkNode(SUB, i, one); + Node uf_i = nm->mkNode(APPLY_UF, uf, i); + Node combine_0 = nm->mkNode(APPLY_UF, combine, zero); + Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne); + Node combine_i = nm->mkNode(APPLY_UF, combine, i); + Node combine_n = nm->mkNode(APPLY_UF, combine, n); + Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero); + Node unionDisjoint_iMinusOne = nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne); + Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i); + Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n); + Node combine_0_equal = combine_0.eqNode(t); + Node combine_i_equal = + combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne)); + Node unionDisjoint_0_equal = + unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType))); + Node singleton = nm->mkBag(elementType, uf_i, one); + + Node unionDisjoint_i_equal = unionDisjoint_i.eqNode( + nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne)); + Node interval_i = + nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n)); + + Node body_i = + nm->mkNode(IMPLIES, + interval_i, + nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal)); + Node forAll_i = quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i); + Node nonNegative = nm->mkNode(GEQ, n, zero); + Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n); + asserts.push_back(forAll_i); + asserts.push_back(combine_0_equal); + asserts.push_back(unionDisjoint_0_equal); + asserts.push_back(unionDisjoint_n_equal); + asserts.push_back(nonNegative); + return combine_n; } Node BagReduction::reduceCardOperator(Node node, std::vector& asserts) @@ -206,6 +202,43 @@ Node BagReduction::reduceCardOperator(Node node, std::vector& asserts) return cardinality_n; } +Node BagReduction::reduceAggregateOperator(Node node) +{ + Assert(node.getKind() == TABLE_AGGREGATE); + NodeManager* nm = NodeManager::currentNM(); + BoundVarManager* bvm = nm->getBoundVarManager(); + Node function = node[0]; + TypeNode elementType = function.getType().getArgTypes()[0]; + Node initialValue = node[1]; + Node A = node[2]; + const std::vector& indices = + node.getOperator().getConst().getIndices(); + + Node t1 = bvm->mkBoundVar(node, "t1", elementType); + Node t2 = bvm->mkBoundVar(node, "t2", elementType); + Node list = nm->mkNode(BOUND_VAR_LIST, t1, t2); + Node body = nm->mkConst(true); + for (uint32_t i : indices) + { + Node select1 = datatypes::TupleUtils::nthElementOfTuple(t1, i); + Node select2 = datatypes::TupleUtils::nthElementOfTuple(t2, i); + Node equal = select1.eqNode(select2); + body = body.andNode(equal); + } + + Node lambda = nm->mkNode(LAMBDA, list, body); + Node partition = nm->mkNode(BAG_PARTITION, lambda, A); + + Node bag = bvm->mkBoundVar( + partition, "bag", nm->mkBagType(elementType)); + Node foldList = nm->mkNode(BOUND_VAR_LIST, bag); + Node foldBody = nm->mkNode(BAG_FOLD, function, initialValue, bag); + + Node fold = nm->mkNode(LAMBDA, foldList, foldBody); + Node map = nm->mkNode(BAG_MAP, fold, partition); + return map; +} + } // namespace bags } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/bags/bag_reduction.h b/src/theory/bags/bag_reduction.h index 55933b1dc..c3f49b0a4 100644 --- a/src/theory/bags/bag_reduction.h +++ b/src/theory/bags/bag_reduction.h @@ -29,10 +29,10 @@ namespace bags { /** * class for bag reductions */ -class BagReduction : EnvObj +class BagReduction { public: - BagReduction(Env& env); + BagReduction(); ~BagReduction(); /** @@ -64,7 +64,7 @@ class BagReduction : EnvObj * combine: Int -> T2 is an uninterpreted function * unionDisjoint: Int -> (Bag T1) is an uninterpreted function */ - Node reduceFoldOperator(Node node, std::vector& asserts); + static Node reduceFoldOperator(Node node, std::vector& asserts); /** * @param node a term of the form (bag.card A) where A: (Bag T) is a bag @@ -95,9 +95,21 @@ class BagReduction : EnvObj * cardinality: Int -> Int is an uninterpreted function * unionDisjoint: Int -> (Bag T1) is an uninterpreted function */ - Node reduceCardOperator(Node node, std::vector& asserts); - - private: + static Node reduceCardOperator(Node node, std::vector& asserts); + /** + * @param node of the form ((_ table.aggr n1 ... nk) f initial A)) + * @return reduction term that uses map, fold, and partition using + * tuple projection as the equivalence relation as follows: + * (bag.map + * (lambda ((B Table)) (bag.fold f initial B)) + * (bag.partition + * (lambda ((t1 Tuple) (t2 Tuple)) ; equivalence relation + * (= + * ((_ tuple.project n1 ... nk) t1) + * ((_ tuple.project n1 ... nk) t2))) + * A)) + */ + static Node reduceAggregateOperator(Node node); }; } // namespace bags diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp index f7864d01a..5118df462 100644 --- a/src/theory/bags/bag_solver.cpp +++ b/src/theory/bags/bag_solver.cpp @@ -79,6 +79,7 @@ void BagSolver::checkBasicOperations() case kind::BAG_FILTER: checkFilter(n); break; case kind::BAG_MAP: checkMap(n); break; case kind::TABLE_PRODUCT: checkProduct(n); break; + case kind::TABLE_JOIN: checkJoin(n); break; default: break; } it++; @@ -335,6 +336,30 @@ void BagSolver::checkProduct(Node n) } } +void BagSolver::checkJoin(Node n) +{ + Assert(n.getKind() == TABLE_JOIN); + const set& elementsA = d_state.getElements(n[0]); + const set& elementsB = d_state.getElements(n[1]); + + for (const Node& e1 : elementsA) + { + for (const Node& e2 : elementsB) + { + InferInfo i = d_ig.joinUp( + n, d_state.getRepresentative(e1), d_state.getRepresentative(e2)); + d_im.lemmaTheoryInference(&i); + } + } + + std::set elements = d_state.getElements(n); + for (const Node& e : elements) + { + InferInfo i = d_ig.joinDown(n, d_state.getRepresentative(e)); + d_im.lemmaTheoryInference(&i); + } +} + } // namespace bags } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h index c0737ef2f..2d6e890df 100644 --- a/src/theory/bags/bag_solver.h +++ b/src/theory/bags/bag_solver.h @@ -100,6 +100,8 @@ class BagSolver : protected EnvObj void checkFilter(Node n); /** apply inference rules for product operator */ void checkProduct(Node n); + /** apply inference rules for join operator */ + void checkJoin(Node n); /** The solver state object */ SolverState& d_state; diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 6b7e49a31..e401551a7 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -68,7 +68,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) } else if (BagsUtils::areChildrenConstants(n)) { - Node value = BagsUtils::evaluate(n); + Node value = BagsUtils::evaluate(d_rewriter, n); response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION); } else @@ -95,6 +95,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) case BAG_FOLD: response = postRewriteFold(n); break; case BAG_PARTITION: response = postRewritePartition(n); break; case TABLE_PRODUCT: response = postRewriteProduct(n); break; + case TABLE_AGGREGATE: response = postRewriteAggregate(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); break; } } @@ -665,6 +666,21 @@ BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const return BagsRewriteResponse(n, Rewrite::NONE); } +BagsRewriteResponse BagsRewriter::postRewriteAggregate(const TNode& n) const +{ + Assert(n.getKind() == kind::TABLE_AGGREGATE); + if (n[1].isConst() && n[2].isConst()) + { + Node ret = BagsUtils::evaluateTableAggregate(d_rewriter, n); + if (ret != n) + { + return BagsRewriteResponse(ret, Rewrite::AGGREGATE_CONST); + } + } + + return BagsRewriteResponse(n, Rewrite::NONE); +} + BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const { Assert(n.getKind() == TABLE_PRODUCT); diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index 3c08208a8..711846143 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -247,6 +247,7 @@ class BagsRewriter : public TheoryRewriter */ BagsRewriteResponse postRewriteFold(const TNode& n) const; BagsRewriteResponse postRewritePartition(const TNode& n) const; + BagsRewriteResponse postRewriteAggregate(const TNode& n) const; /** * rewrites for n include: * - (bag.product A (as bag.empty T2)) = (as bag.empty T) diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp index fd5a98c25..7ee00c432 100644 --- a/src/theory/bags/bags_utils.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -19,6 +19,7 @@ #include "expr/emptybag.h" #include "smt/logic_exception.h" #include "table_project_op.h" +#include "theory/bags/bag_reduction.h" #include "theory/datatypes/tuple_utils.h" #include "theory/rewriter.h" #include "theory/sets/normal_form.h" @@ -118,7 +119,7 @@ bool BagsUtils::areChildrenConstants(TNode n) return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); } -Node BagsUtils::evaluate(TNode n) +Node BagsUtils::evaluate(Rewriter* rewriter, TNode n) { Assert(areChildrenConstants(n)); if (n.isConst()) @@ -144,6 +145,7 @@ Node BagsUtils::evaluate(TNode n) case BAG_FILTER: return evaluateBagFilter(n); case BAG_FOLD: return evaluateBagFold(n); case TABLE_PRODUCT: return evaluateProduct(n); + case TABLE_JOIN: return evaluateJoin(rewriter, n); case TABLE_PROJECT: return evaluateTableProject(n); default: break; } @@ -879,9 +881,22 @@ Node BagsUtils::evaluateBagPartition(Rewriter* rewriter, TNode n) return ret; } +Node BagsUtils::evaluateTableAggregate(Rewriter* rewriter, TNode n) +{ + Assert(n.getKind() == TABLE_AGGREGATE); + if (!(n[1].isConst() && n[2].isConst())) + { + // we can't proceed further. + return n; + } + + Node reduction = BagReduction::reduceAggregateOperator(n); + return reduction; +} + Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2) { - Assert(n.getKind() == TABLE_PRODUCT); + Assert(n.getKind() == TABLE_PRODUCT || n.getKind() == TABLE_JOIN); Node A = n[0]; Node B = n[1]; TypeNode typeA = A.getType().getBagElementType(); @@ -925,6 +940,41 @@ Node BagsUtils::evaluateProduct(TNode n) return ret; } +Node BagsUtils::evaluateJoin(Rewriter* rewriter, TNode n) +{ + Assert(n.getKind() == TABLE_JOIN); + + Node A = n[0]; + Node B = n[1]; + auto [aIndices, bIndices] = splitTableJoinIndices(n); + + std::map elementsA = BagsUtils::getBagElements(A); + std::map elementsB = BagsUtils::getBagElements(B); + + std::map elements; + + for (const auto& [a, countA] : elementsA) + { + Node aProjection = TupleUtils::getTupleProjection(aIndices, a); + aProjection = rewriter->rewrite(aProjection); + Assert(aProjection.isConst()); + for (const auto& [b, countB] : elementsB) + { + Node bProjection = TupleUtils::getTupleProjection(bIndices, b); + bProjection = rewriter->rewrite(bProjection); + Assert(bProjection.isConst()); + if (aProjection == bProjection) + { + Node element = constructProductTuple(n, a, b); + elements[element] = countA * countB; + } + } + } + + Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements); + return ret; +} + Node BagsUtils::evaluateTableProject(TNode n) { Assert(n.getKind() == TABLE_PROJECT); @@ -955,6 +1005,24 @@ Node BagsUtils::evaluateTableProject(TNode n) return ret; } +std::pair, std::vector> +BagsUtils::splitTableJoinIndices(Node n) +{ + Assert(n.getKind() == kind::TABLE_JOIN && n.hasOperator() + && n.getOperator().getKind() == kind::TABLE_JOIN_OP); + TableJoinOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + size_t joinSize = indices.size() / 2; + std::vector indices1(joinSize), indices2(joinSize); + + for (size_t i = 0, index = 0; i < joinSize; i += 2, ++index) + { + indices1[index] = indices[i]; + indices2[index] = indices[i + 1]; + } + return std::make_pair(indices1, indices2); +} + } // namespace bags } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/bags/bags_utils.h b/src/theory/bags/bags_utils.h index 21de8e959..23f21371b 100644 --- a/src/theory/bags/bags_utils.h +++ b/src/theory/bags/bags_utils.h @@ -54,7 +54,7 @@ class BagsUtils * evaluate the node n to a constant value. * As a precondition, children of n should be constants. */ - static Node evaluate(TNode n); + static Node evaluate(Rewriter* rewriter, TNode n); /** * get the elements along with their multiplicities in a given bag @@ -94,7 +94,14 @@ class BagsUtils * @param n has the form (bag.partition r A) where A is a constant bag * @return a partition of A based on the equivalence relation r */ - static Node evaluateBagPartition(Rewriter *rewriter, TNode n); + static Node evaluateBagPartition(Rewriter* rewriter, TNode n); + + /** + * @param n has the form ((_ table.aggr n1 ... n_k) f initial A) + * where initial and A are constants + * @return the aggregation result. + */ + static Node evaluateTableAggregate(Rewriter* rewriter, TNode n); /** * @param n has the form (bag.filter p A) where A is a constant bag @@ -117,6 +124,14 @@ class BagsUtils */ static Node evaluateProduct(TNode n); + /** + * @param n of the form ((_ table.join (m_1 n_1 ... m_k n_k) ) A B) where + * A, B are constants + * @return the evaluation of inner joining tables A B on columns (m_1, n_1, + * ..., m_k, n_k) + */ + static Node evaluateJoin(Rewriter* rewriter, TNode n); + /** * @param n of the form ((_ table.project i_1 ... i_n) A) where A is a * constant @@ -124,6 +139,14 @@ class BagsUtils */ static Node evaluateTableProject(TNode n); + /** + * @param n has the form ((_ table.join m1 n1 ... mk nk) A B)) where A, B are + * tables and m1 n1 ... mk nk are indices + * @return the pair <[m1 ... mk], [n1 ... nk]> + */ + static std::pair, std::vector> + splitTableJoinIndices(Node n); + private: /** * a high order helper function that return a constant bag that is the result diff --git a/src/theory/bags/card_solver.cpp b/src/theory/bags/card_solver.cpp index 18a82bb7b..1db567d54 100644 --- a/src/theory/bags/card_solver.cpp +++ b/src/theory/bags/card_solver.cpp @@ -34,7 +34,7 @@ namespace theory { namespace bags { CardSolver::CardSolver(Env& env, SolverState& s, InferenceManager& im) - : EnvObj(env), d_state(s), d_ig(&s, &im), d_im(im), d_bagReduction(env) + : EnvObj(env), d_state(s), d_ig(&s, &im), d_im(im) { d_nm = NodeManager::currentNM(); d_zero = d_nm->mkConstInt(Rational(0)); @@ -238,7 +238,7 @@ void CardSolver::addChildren(const Node& premise, Node card = d_nm->mkNode(BAG_CARD, parent); std::vector asserts; - Node reduced = d_bagReduction.reduceCardOperator(card, asserts); + Node reduced = BagReduction::reduceCardOperator(card, asserts); asserts.push_back(card.eqNode(reduced)); InferInfo inferInfo(&d_im, InferenceId::BAGS_CARD); inferInfo.d_premises.push_back(premise); diff --git a/src/theory/bags/card_solver.h b/src/theory/bags/card_solver.h index c72e8eb98..8802564e6 100644 --- a/src/theory/bags/card_solver.h +++ b/src/theory/bags/card_solver.h @@ -119,9 +119,6 @@ class CardSolver : protected EnvObj InferenceManager& d_im; NodeManager* d_nm; - /** bag reduction */ - BagReduction d_bagReduction; - /** * A map from bag representatives to sets of bag representatives with the * invariant that each key is the disjoint union of each set in the value. diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 2000443a2..715c32662 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -23,6 +23,7 @@ #include "theory/bags/bags_utils.h" #include "theory/bags/inference_manager.h" #include "theory/bags/solver_state.h" +#include "theory/bags/table_project_op.h" #include "theory/datatypes/tuple_utils.h" #include "theory/quantifiers/fmf/bounded_integers.h" #include "theory/uf/equality_engine.h" @@ -590,6 +591,9 @@ InferInfo InferenceGenerator::productUp(Node n, Node e1, Node e2) Node countA = getMultiplicityTerm(e1, A); Node countB = getMultiplicityTerm(e2, B); + inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countA, d_one)); + inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countB, d_one)); + Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag"); Node count = getMultiplicityTerm(tuple, skolem); @@ -625,9 +629,91 @@ InferInfo InferenceGenerator::productDown(Node n, Node e) Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag"); Node count = getMultiplicityTerm(e, skolem); + inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count, d_one)); + + Node multiply = d_nm->mkNode(MULT, countA, countB); + inferInfo.d_conclusion = count.eqNode(multiply); + + return inferInfo; +} + +InferInfo InferenceGenerator::joinUp(Node n, Node e1, Node e2) +{ + Assert(n.getKind() == TABLE_JOIN); + Node A = n[0]; + Node B = n[1]; + Node tuple = BagsUtils::constructProductTuple(n, e1, e2); + + std::vector aElements = TupleUtils::getTupleElements(e1); + std::vector bElements = TupleUtils::getTupleElements(e2); + const std::vector& indices = + n.getOperator().getConst().getIndices(); + + InferInfo inferInfo(d_im, InferenceId::TABLES_PRODUCT_UP); + + for (size_t i = 0; i < indices.size(); i += 2) + { + Node x = aElements[indices[i]]; + Node y = bElements[indices[i + 1]]; + Node equal = x.eqNode(y); + inferInfo.d_premises.push_back(equal); + } + + Node countA = getMultiplicityTerm(e1, A); + Node countB = getMultiplicityTerm(e2, B); + + inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countA, d_one)); + inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countB, d_one)); + + Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag"); + Node count = getMultiplicityTerm(tuple, skolem); Node multiply = d_nm->mkNode(MULT, countA, countB); inferInfo.d_conclusion = count.eqNode(multiply); + return inferInfo; +} + +InferInfo InferenceGenerator::joinDown(Node n, Node e) +{ + Assert(n.getKind() == TABLE_JOIN); + Assert(e.getType().isSubtypeOf(n.getType().getBagElementType())); + + Node A = n[0]; + Node B = n[1]; + + TypeNode tupleBType = B.getType().getBagElementType(); + TypeNode tupleAType = A.getType().getBagElementType(); + size_t tupleALength = tupleAType.getTupleLength(); + size_t productTupleLength = n.getType().getBagElementType().getTupleLength(); + + std::vector elements = TupleUtils::getTupleElements(e); + Node a = TupleUtils::constructTupleFromElements( + tupleAType, elements, 0, tupleALength - 1); + Node b = TupleUtils::constructTupleFromElements( + tupleBType, elements, tupleALength, productTupleLength - 1); + + InferInfo inferInfo(d_im, InferenceId::TABLES_JOIN_DOWN); + + Node countA = getMultiplicityTerm(a, A); + Node countB = getMultiplicityTerm(b, B); + + Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag"); + Node count = getMultiplicityTerm(e, skolem); + inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count, d_one)); + + Node multiply = d_nm->mkNode(MULT, countA, countB); + Node multiplicityConstraint = count.eqNode(multiply); + const std::vector& indices = + n.getOperator().getConst().getIndices(); + Node joinConstraints = d_true; + for (size_t i = 0; i < indices.size(); i += 2) + { + Node x = elements[indices[i]]; + Node y = elements[tupleALength + indices[i + 1]]; + Node equal = x.eqNode(y); + joinConstraints = joinConstraints.andNode(equal); + } + inferInfo.d_conclusion = joinConstraints.andNode(multiplicityConstraint); return inferInfo; } diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h index c0143c16c..7d92ea4b0 100644 --- a/src/theory/bags/inference_generator.h +++ b/src/theory/bags/inference_generator.h @@ -306,28 +306,60 @@ class InferenceGenerator InferInfo filterUpwards(Node n, Node e); /** - * @param n is a (table.product A B) where A, B are bags of tuples + * @param n is a (table.product A B) where A, B are tables * @param e1 an element of the form (tuple a1 ... am) * @param e2 an element of the form (tuple b1 ... bn) * @return an inference that represents the following - * (= - * (bag.count (tuple a1 ... am b1 ... bn) skolem) - * (* (bag.count e1 A) (bag.count e2 B))) + * (=> (and (bag.member e1 A) (bag.member e2 B)) + * (= + * (bag.count (tuple a1 ... am b1 ... bn) skolem) + * (* (bag.count e1 A) (bag.count e2 B)))) * where skolem is a variable equals (bag.product A B) */ InferInfo productUp(Node n, Node e1, Node e2); /** - * @param n is a (table.product A B) where A, B are bags of tuples + * @param n is a (table.product A B) where A, B are tables * @param e an element of the form (tuple a1 ... am b1 ... bn) * @return an inference that represents the following - * (= - * (bag.count (tuple a1 ... am b1 ... bn) skolem) - * (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B))) + * (=> (bag.member e skolem) + * (= + * (bag.count (tuple a1 ... am b1 ... bn) skolem) + * (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B)))) * where skolem is a variable equals (bag.product A B) */ InferInfo productDown(Node n, Node e); + /** + * @param n is a ((_ table.join m1 n1 ... mk nk) A B) where A, B are tables + * @param e1 an element of the form (tuple a1 ... am) + * @param e2 an element of the form (tuple b1 ... bn) + * @return an inference that represents the following + * (=> (and + * (bag.member e1 A) + * (bag.member e2 B) + * (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk})) + * (= + * (bag.count (tuple a1 ... am b1 ... bn) skolem) + * (* (bag.count e1 A) (bag.count e2 B)))) + * where skolem is a variable equals ((_ table.join m1 n1 ... mk nk) A B) + */ + InferInfo joinUp(Node n, Node e1, Node e2); + + /** + * @param n is a (table.product A B) where A, B are tables + * @param e an element of the form (tuple a1 ... am b1 ... bn) + * @return an inference that represents the following + * (=> (bag.member e skolem) + * (and + * (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk}) + * (= + * (bag.count (tuple a1 ... am b1 ... bn) skolem) + * (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B)))) + * where skolem is a variable equals ((_ table.join m1 n1 ... mk nk) A B) + */ + InferInfo joinDown(Node n, Node e); + /** * @param element of type T * @param bag of type (bag T) diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 1e875e998..ecde27c62 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -132,8 +132,32 @@ constant TABLE_PROJECT_OP \ parameterized TABLE_PROJECT TABLE_PROJECT_OP 1 "table projection" +# table.aggregate operator +constant TABLE_AGGREGATE_OP \ + class \ + TableAggregateOp \ + ::cvc5::internal::TableAggregateOpHashFunction \ + "theory/bags/table_project_op.h" \ + "operator for TABLE_AGGREGATE; payload is an instance of the cvc5::internal::TableAggregateOp class" + +parameterized TABLE_AGGREGATE TABLE_AGGREGATE_OP 3 "table aggregate" + +# table.join operator +constant TABLE_JOIN_OP \ + class \ + TableJoinOp \ + ::cvc5::internal::TableJoinOpHashFunction \ + "theory/bags/table_project_op.h" \ + "operator for TABLE_JOIN; payload is an instance of the cvc5::internal::TableJoinOp class" + +parameterized TABLE_JOIN TABLE_JOIN_OP 2 "table join" + typerule TABLE_PRODUCT ::cvc5::internal::theory::bags::TableProductTypeRule typerule TABLE_PROJECT_OP "SimpleTypeRule" typerule TABLE_PROJECT ::cvc5::internal::theory::bags::TableProjectTypeRule +typerule TABLE_AGGREGATE_OP "SimpleTypeRule" +typerule TABLE_AGGREGATE ::cvc5::internal::theory::bags::TableAggregateTypeRule +typerule TABLE_JOIN_OP "SimpleTypeRule" +typerule TABLE_JOIN ::cvc5::internal::theory::bags::TableJoinTypeRule endtheory diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index 0c634351a..17d1d8f9a 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -26,6 +26,7 @@ const char* toString(Rewrite r) switch (r) { case Rewrite::NONE: return "NONE"; + case Rewrite::AGGREGATE_CONST: return "AGGREGATE_CONST"; case Rewrite::BAG_MAKE_COUNT_NEGATIVE: return "BAG_MAKE_COUNT_NEGATIVE"; case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT"; case Rewrite::CARD_BAG_MAKE: return "CARD_BAG_MAKE"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index 461ea8703..eb42e053d 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -31,6 +31,7 @@ namespace bags { enum class Rewrite : uint32_t { NONE, // no rewrite happened + AGGREGATE_CONST, BAG_MAKE_COUNT_NEGATIVE, CARD_DISJOINT, CARD_BAG_MAKE, diff --git a/src/theory/bags/table_project_op.cpp b/src/theory/bags/table_project_op.cpp index 426753d8a..72700be9d 100644 --- a/src/theory/bags/table_project_op.cpp +++ b/src/theory/bags/table_project_op.cpp @@ -22,4 +22,14 @@ TableProjectOp::TableProjectOp(std::vector indices) { } +TableAggregateOp::TableAggregateOp(std::vector indices) + : ProjectOp(std::move(indices)) +{ +} + +TableJoinOp::TableJoinOp(std::vector indices) + : ProjectOp(std::move(indices)) +{ +} + } // namespace cvc5::internal diff --git a/src/theory/bags/table_project_op.h b/src/theory/bags/table_project_op.h index a061537e9..10c45f915 100644 --- a/src/theory/bags/table_project_op.h +++ b/src/theory/bags/table_project_op.h @@ -34,11 +34,41 @@ class TableProjectOp : public ProjectOp }; /* class TableProjectOp */ /** - * Hash function for the TupleProjectOpHashFunction objects. + * Hash function for the TableProjectOpHashFunction objects. */ struct TableProjectOpHashFunction : public ProjectOpHashFunction { -}; /* struct TupleProjectOpHashFunction */ +}; /* struct TableProjectOpHashFunction */ + +class TableAggregateOp : public ProjectOp +{ + public: + explicit TableAggregateOp(std::vector indices); + TableAggregateOp(const TableAggregateOp& op) = default; +}; /* class TableAggregateOp */ + +/** + * Hash function for the TableAggregateOpHashFunction objects. + */ +struct TableAggregateOpHashFunction : public ProjectOpHashFunction +{ +}; /* struct TableAggregateOpHashFunction */ + + +class TableJoinOp : public ProjectOp +{ + public: + explicit TableJoinOp(std::vector indices); + TableJoinOp(const TableJoinOp& op) = default; +}; /* class TableJoinOp */ + +/** + * Hash function for the TableJoinOpHashFunction objects. + */ +struct TableJoinOpHashFunction : public ProjectOpHashFunction +{ +}; /* struct TableJoinOpHashFunction */ + } // namespace cvc5::internal diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index adcf3d468..3fa491a7c 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -41,8 +41,7 @@ TheoryBags::TheoryBags(Env& env, OutputChannel& out, Valuation valuation) d_rewriter(env.getRewriter(), &d_statistics.d_rewrites), d_termReg(env, d_state, d_im), d_solver(env, d_state, d_im, d_termReg), - d_cardSolver(env, d_state, d_im), - d_bagReduction(env) + d_cardSolver(env, d_state, d_im) { // use the official theory state and inference manager objects d_theoryState = &d_state; @@ -83,6 +82,8 @@ void TheoryBags::finishInit() d_equalityEngine->addFunctionKind(BAG_PARTITION); d_equalityEngine->addFunctionKind(TABLE_PRODUCT); d_equalityEngine->addFunctionKind(TABLE_PROJECT); + d_equalityEngine->addFunctionKind(TABLE_AGGREGATE); + d_equalityEngine->addFunctionKind(TABLE_JOIN); } TrustNode TheoryBags::ppRewrite(TNode atom, std::vector& lems) @@ -95,7 +96,7 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector& lems) case kind::BAG_FOLD: { std::vector asserts; - Node ret = d_bagReduction.reduceFoldOperator(atom, asserts); + Node ret = BagReduction::reduceFoldOperator(atom, asserts); NodeManager* nm = NodeManager::currentNM(); Node andNode = nm->mkNode(AND, asserts); d_im.lemma(andNode, InferenceId::BAGS_FOLD); @@ -104,6 +105,12 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector& lems) << andNode << std::endl; return TrustNode::mkTrustRewrite(atom, ret, nullptr); } + case kind::TABLE_AGGREGATE: + { + Node ret = BagReduction::reduceAggregateOperator(atom); + Trace("bags::ppr") << "reduce(" << atom << ") = " << ret << std::endl; + return TrustNode::mkTrustRewrite(atom, ret, nullptr); + } default: return TrustNode::null(); } } diff --git a/src/theory/bags/theory_bags.h b/src/theory/bags/theory_bags.h index 8259a2392..9c95f991e 100644 --- a/src/theory/bags/theory_bags.h +++ b/src/theory/bags/theory_bags.h @@ -130,9 +130,6 @@ class TheoryBags : public Theory /** the main solver for bags */ CardSolver d_cardSolver; - /** bag reduction */ - BagReduction d_bagReduction; - /** The representation of the strategy */ Strategy d_strat; diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index e786a6afc..5ca0f0815 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -33,6 +33,8 @@ namespace cvc5::internal { namespace theory { namespace bags { +using namespace datatypes; + TypeNode BinaryOperatorTypeRule::computeType(NodeManager* nodeManager, TNode n, bool check) @@ -530,16 +532,8 @@ TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager, throw TypeCheckingExceptionPrivate(n, ss.str()); } - std::vector productTupleTypes; - std::vector tupleATypes = elementAType.getTupleTypes(); - std::vector tupleBTypes = elementBType.getTupleTypes(); - - productTupleTypes.insert( - productTupleTypes.end(), tupleATypes.begin(), tupleATypes.end()); - productTupleTypes.insert( - productTupleTypes.end(), tupleBTypes.begin(), tupleBTypes.end()); - - TypeNode retTupleType = nodeManager->mkTupleType(productTupleTypes); + TypeNode retTupleType = + TupleUtils::concatTupleTypes(elementAType, elementBType); TypeNode retType = nodeManager->mkBagType(retTupleType); return retType; } @@ -595,7 +589,140 @@ TypeNode TableProjectTypeRule::computeType(NodeManager* nm, TNode n, bool check) } TypeNode tupleType = bagType.getBagElementType(); TypeNode retTupleType = - datatypes::TupleUtils::getTupleProjectionType(indices, tupleType); + TupleUtils::getTupleProjectionType(indices, tupleType); + return nm->mkBagType(retTupleType); +} + +TypeNode TableAggregateTypeRule::computeType(NodeManager* nm, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::TABLE_AGGREGATE && n.hasOperator() + && n.getOperator().getKind() == kind::TABLE_AGGREGATE_OP); + TableAggregateOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + + TypeNode functionType = n[0].getType(check); + TypeNode initialValueType = n[1].getType(check); + TypeNode bagType = n[2].getType(check); + + if (check) + { + if (!bagType.isBag()) + { + std::stringstream ss; + ss << "TABLE_PROJECT operator expects a table. Found '" << n[2] + << "' of type '" << bagType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode tupleType = bagType.getBagElementType(); + if (!tupleType.isTuple()) + { + std::stringstream ss; + ss << "TABLE_PROJECT operator expects a table. Found '" << n[2] + << "' of type '" << bagType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TupleUtils::checkTypeIndices(n, tupleType, indices); + + TypeNode elementType = bagType.getBagElementType(); + + if (!(functionType.isFunction())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " T T) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes = functionType.getArgTypes(); + TypeNode rangeType = functionType.getRangeType(); + if (!(argTypes.size() == 2 && argTypes[0] == elementType + && argTypes[1] == rangeType)) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " T T). " + << "Found a function of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + if (rangeType != initialValueType) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects an initial value of type " + << rangeType << ". Found a term of type '" << initialValueType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + return nm->mkBagType(functionType.getRangeType()); +} + +TypeNode TableJoinTypeRule::computeType(NodeManager* nm, TNode n, bool check) +{ + Assert(n.getKind() == kind::TABLE_JOIN && n.hasOperator() + && n.getOperator().getKind() == kind::TABLE_JOIN_OP); + TableJoinOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + Node A = n[0]; + Node B = n[1]; + TypeNode aType = A.getType(); + TypeNode bType = B.getType(); + + if (check) + { + if (!(aType.isBag() && bType.isBag())) + { + std::stringstream ss; + ss << "TABLE_JOIN operator expects two tables. Found '" << n[0] << "', '" + << n[1] << "' of types '" << aType << "', '" << bType + << "' respectively. "; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode aTupleType = aType.getBagElementType(); + TypeNode bTupleType = bType.getBagElementType(); + if (!(aTupleType.isTuple() && bTupleType.isTuple())) + { + std::stringstream ss; + ss << "TABLE_JOIN operator expects two tables. Found '" << n[0] << "', '" + << n[1] << "' of types '" << aType << "', '" << bType + << "' respectively. "; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + if (indices.size() % 2 != 0) + { + std::stringstream ss; + ss << "TABLE_JOIN operator expects even number of indices. Found " + << indices.size() << " in term " << n; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + auto [aIndices, bIndices] = BagsUtils::splitTableJoinIndices(n); + TupleUtils::checkTypeIndices(n, aTupleType, aIndices); + TupleUtils::checkTypeIndices(n, bTupleType, bIndices); + + // check the types of columns + std::vector aTypes = aTupleType.getTupleTypes(); + std::vector bTypes = bTupleType.getTupleTypes(); + for (uint32_t i = 0; i < aIndices.size(); i++) + { + if (aTypes[aIndices[i]] != bTypes[bIndices[i]]) + { + std::stringstream ss; + ss << "TABLE_JOIN operator expects column " << aIndices[i] + << " in table " << n[0] << " to match column " << bIndices[i] + << " in table " << n[1] << ". But their types are " + << aTypes[aIndices[i]] << " and " << bTypes[bIndices[i]] + << "' respectively. "; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + } + TypeNode aTupleType = aType.getBagElementType(); + TypeNode bTupleType = bType.getBagElementType(); + TypeNode retTupleType = TupleUtils::concatTupleTypes(aTupleType, bTupleType); return nm->mkBagType(retTupleType); } diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 04e5bfd04..445d627db 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -160,7 +160,8 @@ struct BagFoldTypeRule }; /* struct BagFoldTypeRule */ /** - * Type rule for (bag.partition r A) to make sure r is a binary operation of type + * Type rule for (bag.partition r A) to make sure r is a binary operation of + * type * (-> T1 T1 Bool), and A is a bag of type (Bag T1) */ struct BagPartitionTypeRule @@ -188,6 +189,33 @@ struct TableProjectTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct BagFoldTypeRule */ +/** + * Table aggregate operator is indexed by a list of indices (n_1, ..., n_k). + * It ensures that it has 3 arguments: + * - A combining function of type (-> (Tuple T_1 ... T_j) T T) + * - Initial value of type T + * - A table of type (Table T_1 ... T_j) where 0 <= n_1, ..., n_k < j + * the returned type is (Bag T). + */ +struct TableAggregateTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct TableAggregateTypeRule */ + +/** + * Table join operator is indexed by a list of indices (m_1, m_k, n_1, ..., + * n_k). It ensures that it has 2 arguments: + * - A table of type (Table X_1 ... X_i) + * - A table of type (Table Y_1 ... Y_j) + * such that indices has constraints 0 <= m_1, ..., mk, n_1, ..., n_k <= + * min(i,j) and types has constraints X_{m_1} = Y_{n_1}, ..., X_{m_k} = Y_{n_k}. + * The returned type is (Table X_1 ... X_i Y_1 ... Y_j) + */ +struct TableJoinTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct TableJoinTypeRule */ + struct BagsProperties { static Cardinality computeCardinality(TypeNode type); diff --git a/src/theory/datatypes/tuple_utils.cpp b/src/theory/datatypes/tuple_utils.cpp index 05528a644..838e840c3 100644 --- a/src/theory/datatypes/tuple_utils.cpp +++ b/src/theory/datatypes/tuple_utils.cpp @@ -15,6 +15,8 @@ #include "tuple_utils.h" +#include + #include "expr/dtype.h" #include "expr/dtype_cons.h" @@ -24,6 +26,41 @@ namespace cvc5::internal { namespace theory { namespace datatypes { +void TupleUtils::checkTypeIndices(Node n, + TypeNode tupleType, + const std::vector indices) +{ + // make sure all indices are less than the size of the tuple + DType dType = tupleType.getDType(); + DTypeConstructor constructor = dType[0]; + size_t numArgs = constructor.getNumArgs(); + for (uint32_t index : indices) + { + std::stringstream ss; + if (index >= numArgs) + { + ss << "Index " << index << " in term " << n << " is > " << (numArgs - 1) + << " the maximum value "; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } +} + +TypeNode TupleUtils::concatTupleTypes(TypeNode tupleType1, TypeNode tupleType2) +{ + std::vector concatTupleTypes; + std::vector tuple1Types = tupleType1.getTupleTypes(); + std::vector tuple2Types = tupleType2.getTupleTypes(); + + concatTupleTypes.insert( + concatTupleTypes.end(), tuple1Types.begin(), tuple1Types.end()); + concatTupleTypes.insert( + concatTupleTypes.end(), tuple2Types.begin(), tuple2Types.end()); + NodeManager* nm = NodeManager::currentNM(); + TypeNode ret = nm->mkTupleType(concatTupleTypes); + return ret; +} + Node TupleUtils::nthElementOfTuple(Node tuple, int n_th) { if (tuple.getKind() == APPLY_CONSTRUCTOR) diff --git a/src/theory/datatypes/tuple_utils.h b/src/theory/datatypes/tuple_utils.h index 041121397..9afbd59fe 100644 --- a/src/theory/datatypes/tuple_utils.h +++ b/src/theory/datatypes/tuple_utils.h @@ -25,6 +25,25 @@ namespace datatypes { class TupleUtils { public: + /** + * + * @param n a node to print in the message if TypeCheckingExceptionPrivate + * exception is thrown + * @param tupleType the type of the tuple + * @param indices a list of indices for projection + * @throw an exception if one of the indices in node n is greater than the + * expected tuple's length + */ + static void checkTypeIndices(Node n, + TypeNode tupleType, + const std::vector indices); + /** + * @param tupleType1 tuple type + * @param tupleType2 tuple type + * @return the type of concatenation of tupleType1, tupleType2 + */ + static TypeNode concatTupleTypes(TypeNode tupleType1, TypeNode tupleType2); + /** * @param tuple a node of tuple type * @param n_th the index of the element to be extracted, and must satisfy the diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index f298c6b85..fe4ff0efc 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -140,6 +140,8 @@ const char* toString(InferenceId i) case InferenceId::BAGS_CARD_EMPTY: return "BAGS_CARD_EMPTY"; case InferenceId::TABLES_PRODUCT_UP: return "TABLES_PRODUCT_UP"; case InferenceId::TABLES_PRODUCT_DOWN: return "TABLES_PRODUCT_DOWN"; + case InferenceId::TABLES_JOIN_UP: return "TABLES_JOIN_UP"; + case InferenceId::TABLES_JOIN_DOWN: return "TABLES_JOIN_DOWN"; case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT"; case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA: diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index 034eeacc1..b56168e72 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -206,6 +206,8 @@ enum class InferenceId BAGS_CARD_EMPTY, TABLES_PRODUCT_UP, TABLES_PRODUCT_DOWN, + TABLES_JOIN_UP, + TABLES_JOIN_DOWN, // ---------------------------------- end bags theory // ---------------------------------- bitvector theory diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index 13324b7a4..100138346 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -1813,6 +1813,10 @@ set(regress_1_tests regress1/bags/proj-issue497.smt2 regress1/bags/subbag1.smt2 regress1/bags/subbag2.smt2 + regress1/bags/table_aggregate1.smt2 + regress1/bags/table_join1.smt2 + regress1/bags/table_join2.smt2 + regress1/bags/table_join3.smt2 regress1/bags/table_project1.smt2 regress1/bags/union_disjoint.smt2 regress1/bags/union_max1.smt2 diff --git a/test/regress/cli/regress1/bags/table_aggregate1.smt2 b/test/regress/cli/regress1/bags/table_aggregate1.smt2 new file mode 100644 index 000000000..637def3dc --- /dev/null +++ b/test/regress/cli/regress1/bags/table_aggregate1.smt2 @@ -0,0 +1,31 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(set-option :fmf-bound true) +(set-option :uf-lazy-ll true) + +(define-fun sumByCategory ((x (Tuple String String Int)) (y (Tuple String Int))) (Tuple String Int) + (tuple + ((_ tuple.select 0) x) + (+ ((_ tuple.select 2) x) ((_ tuple.select 1) y)))) + +(declare-fun categorySales () (Bag (Tuple String Int))) + +;(define-fun categorySales () (Bag (Tuple String Int)) +; (bag.union_disjoint +; (bag (tuple "Hardware" 10) 1) +; (bag (tuple "Software" 6) 1))) + +(assert + (= categorySales + ((_ table.aggr 0) + sumByCategory + (tuple "" 0) + (bag.union_disjoint + (bag (tuple "Software" "win" 1) 2) + (bag (tuple "Software" "mac" 4) 1) + (bag (tuple "Hardware" "cpu" 2) 2) + (bag (tuple "Hardware" "gpu" 3) 2))))) + +(check-sat) diff --git a/test/regress/cli/regress1/bags/table_join1.smt2 b/test/regress/cli/regress1/bags/table_join1.smt2 new file mode 100644 index 000000000..349137f50 --- /dev/null +++ b/test/regress/cli/regress1/bags/table_join1.smt2 @@ -0,0 +1,29 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(declare-fun Departments () (Table Int String)) +(declare-fun Students () (Table Int String Int)) +(declare-fun DepartmentStudents () (Table Int String Int String Int)) + +(assert + (= Departments + (bag.union_disjoint + (bag (tuple 1 "Computer") 1) + (bag (tuple 2 "Engineering") 1)))) + +(assert + (= Students + (bag.union_disjoint + (bag (tuple 1 "A" 1) 1) + (bag (tuple 2 "B" 1) 1) + (bag (tuple 3 "C" 2) 1)))) + +;(define-fun DepartmentStudents () (Bag (Tuple Int String Int String Int)) +; (bag.union_disjoint (bag (tuple 1 "Computer" 1 "A" 1) 1) +; (bag (tuple 1 "Computer" 2 "B" 1) 1) +; (bag (tuple 2 "Engineering" 3 "C" 2) 1))) + +(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students))) + +(check-sat) diff --git a/test/regress/cli/regress1/bags/table_join2.smt2 b/test/regress/cli/regress1/bags/table_join2.smt2 new file mode 100644 index 000000000..2c86e859d --- /dev/null +++ b/test/regress/cli/regress1/bags/table_join2.smt2 @@ -0,0 +1,28 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(declare-fun Departments () (Table Int String)) +(declare-fun Students () (Table Int String Int)) +(declare-fun DepartmentStudents () (Table Int String Int String Int)) + +;(define-fun Departments () (Bag (Tuple Int String)) +; (bag.union_disjoint +; (bag (tuple 1 "Computer") 1) +; (bag (tuple 2 "Engineering") 1))) + +;(define-fun Students () (Bag (Tuple Int String Int)) +; (bag.union_disjoint +; (bag (tuple 1 "A" 1) 1) +; (bag (tuple 2 "B" 1) 1) +; (bag (tuple 3 "C" 2) 1))) + +(assert + (= DepartmentStudents + (bag.union_disjoint (bag (tuple 1 "Computer" 1 "A" 1) 1) + (bag (tuple 1 "Computer" 2 "B" 1) 1) + (bag (tuple 2 "Engineering" 3 "C" 2) 1)))) + +(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students))) + +(check-sat) diff --git a/test/regress/cli/regress1/bags/table_join3.smt2 b/test/regress/cli/regress1/bags/table_join3.smt2 new file mode 100644 index 000000000..aa416d0fa --- /dev/null +++ b/test/regress/cli/regress1/bags/table_join3.smt2 @@ -0,0 +1,27 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(declare-fun Departments () (Table Int String)) +(declare-fun Students () (Table Int String Int)) +(declare-fun DepartmentStudents () (Table Int String Int String Int)) + +(declare-fun d1 () (Tuple Int String)) +(declare-fun d2 () (Tuple Int String)) +(assert (distinct d1 d2)) + +(declare-fun s1 () (Tuple Int String Int)) +(declare-fun s2 () (Tuple Int String Int)) +(assert (distinct s1 s2)) + +(assert + (distinct DepartmentStudents (as bag.empty (Table Int String Int String Int)))) + +(assert (bag.member d1 Departments)) +(assert (bag.member d2 Departments)) +(assert (bag.member s1 Students)) +(assert (bag.member s2 Students)) + +(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students))) + +(check-sat) diff --git a/test/unit/theory/theory_bags_normal_form_white.cpp b/test/unit/theory/theory_bags_normal_form_white.cpp index d144f1107..e6cc722a3 100644 --- a/test/unit/theory/theory_bags_normal_form_white.cpp +++ b/test/unit/theory/theory_bags_normal_form_white.cpp @@ -56,6 +56,8 @@ class TestTheoryWhiteBagsNormalForm : public TestSmt return elements; } + Node rewrite(Node n) { return d_rewriter->postRewrite(n).d_node; } + std::unique_ptr d_rewriter; }; @@ -64,7 +66,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, empty_bag_normal_form) Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType())); // empty bags are in normal form ASSERT_TRUE(emptybag.isConst()); - Node n = BagsUtils::evaluate(emptybag); + Node n = rewrite(emptybag); ASSERT_EQ(emptybag, n); } @@ -85,9 +87,9 @@ TEST_F(TestTheoryWhiteBagsNormalForm, mkBag_constant_element) ASSERT_FALSE(negative.isConst()); ASSERT_FALSE(zero.isConst()); - ASSERT_EQ(emptybag, BagsUtils::evaluate(negative)); - ASSERT_EQ(emptybag, BagsUtils::evaluate(zero)); - ASSERT_EQ(positive, BagsUtils::evaluate(positive)); + ASSERT_EQ(emptybag, rewrite(negative)); + ASSERT_EQ(emptybag, rewrite(zero)); + ASSERT_EQ(positive, rewrite(positive)); } TEST_F(TestTheoryWhiteBagsNormalForm, bag_count) @@ -116,25 +118,25 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_count) Node input1 = d_nodeManager->mkNode(BAG_COUNT, x, empty); Node output1 = zero; - ASSERT_EQ(output1, BagsUtils::evaluate(input1)); + ASSERT_EQ(output1, rewrite(input1)); Node input2 = d_nodeManager->mkNode(BAG_COUNT, x, y_5); Node output2 = zero; - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); Node input3 = d_nodeManager->mkNode(BAG_COUNT, x, x_4); Node output3 = four; - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); Node unionDisjointXY = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input4 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointXY); Node output4 = four; - ASSERT_EQ(output3, BagsUtils::evaluate(input3)); + ASSERT_EQ(output3, rewrite(input3)); Node unionDisjointYZ = d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_5, z_5); Node input5 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointYZ); Node output5 = zero; - ASSERT_EQ(output4, BagsUtils::evaluate(input4)); + ASSERT_EQ(output4, rewrite(input4)); } TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) @@ -151,7 +153,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, emptybag); Node output1 = emptybag; - ASSERT_EQ(output1, BagsUtils::evaluate(input1)); + ASSERT_EQ(output1, rewrite(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -168,12 +170,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) Node input2 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, x_4); Node output2 = x_1; - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input3 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, normalBag); Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); - ASSERT_EQ(output3, BagsUtils::evaluate(input3)); + ASSERT_EQ(output3, rewrite(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_max) @@ -213,7 +215,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_max) d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2)); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, BagsUtils::evaluate(input)); + ASSERT_EQ(output, rewrite(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) @@ -234,12 +236,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) Node unionDisjointAB = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B); // unionDisjointAB is already in a normal form ASSERT_TRUE(unionDisjointAB.isConst()); - ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointAB)); + ASSERT_EQ(unionDisjointAB, rewrite(unionDisjointAB)); Node unionDisjointBA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, B, A); // unionDisjointAB is the normal form of unionDisjointBA ASSERT_FALSE(unionDisjointBA.isConst()); - ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointBA)); + ASSERT_EQ(unionDisjointAB, rewrite(unionDisjointBA)); Node unionDisjointAB_C = d_nodeManager->mkNode(BAG_UNION_DISJOINT, unionDisjointAB, C); @@ -249,7 +251,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) // unionDisjointA_BC is the normal form of unionDisjointAB_C ASSERT_FALSE(unionDisjointAB_C.isConst()); ASSERT_TRUE(unionDisjointA_BC.isConst()); - ASSERT_EQ(unionDisjointA_BC, BagsUtils::evaluate(unionDisjointAB_C)); + ASSERT_EQ(unionDisjointA_BC, rewrite(unionDisjointAB_C)); Node unionDisjointAA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, A); Node AA = d_nodeManager->mkBag(d_nodeManager->stringType(), @@ -257,7 +259,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) d_nodeManager->mkConstInt(Rational(4))); ASSERT_FALSE(unionDisjointAA.isConst()); ASSERT_TRUE(AA.isConst()); - ASSERT_EQ(AA, BagsUtils::evaluate(unionDisjointAA)); + ASSERT_EQ(AA, rewrite(unionDisjointAA)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2) @@ -297,7 +299,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2) d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2)); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, BagsUtils::evaluate(input)); + ASSERT_EQ(output, rewrite(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min) @@ -332,7 +334,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min) Node output = x_3; ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, BagsUtils::evaluate(input)); + ASSERT_EQ(output, rewrite(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract) @@ -369,7 +371,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract) Node output = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, z_2); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, BagsUtils::evaluate(input)); + ASSERT_EQ(output, rewrite(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove) @@ -406,7 +408,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove) Node output = z_2; ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, BagsUtils::evaluate(input)); + ASSERT_EQ(output, rewrite(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, bag_card) @@ -429,16 +431,16 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_card) Node input1 = d_nodeManager->mkNode(BAG_CARD, empty); Node output1 = d_nodeManager->mkConstInt(Rational(0)); - ASSERT_EQ(output1, BagsUtils::evaluate(input1)); + ASSERT_EQ(output1, rewrite(input1)); Node input2 = d_nodeManager->mkNode(BAG_CARD, x_4); Node output2 = d_nodeManager->mkConstInt(Rational(4)); - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_1); Node input3 = d_nodeManager->mkNode(BAG_CARD, union_disjoint); Node output3 = d_nodeManager->mkConstInt(Rational(5)); - ASSERT_EQ(output3, BagsUtils::evaluate(input3)); + ASSERT_EQ(output3, rewrite(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton) @@ -466,20 +468,20 @@ TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton) Node input1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, empty); Node output1 = falseNode; - ASSERT_EQ(output1, BagsUtils::evaluate(input1)); + ASSERT_EQ(output1, rewrite(input1)); Node input2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_1); Node output2 = trueNode; - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); Node input3 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_4); Node output3 = falseNode; - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); Node input4 = d_nodeManager->mkNode(BAG_IS_SINGLETON, union_disjoint); Node output4 = falseNode; - ASSERT_EQ(output3, BagsUtils::evaluate(input3)); + ASSERT_EQ(output3, rewrite(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, from_set) @@ -497,7 +499,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_FROM_SET, emptyset); Node output1 = emptybag; - ASSERT_EQ(output1, BagsUtils::evaluate(input1)); + ASSERT_EQ(output1, rewrite(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -512,13 +514,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set) Node input2 = d_nodeManager->mkNode(BAG_FROM_SET, xSingleton); Node output2 = x_1; - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); // for normal sets, the first node is the largest, not smallest Node normalSet = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton); Node input3 = d_nodeManager->mkNode(BAG_FROM_SET, normalSet); Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); - ASSERT_EQ(output3, BagsUtils::evaluate(input3)); + ASSERT_EQ(output3, rewrite(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, to_set) @@ -536,7 +538,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_TO_SET, emptybag); Node output1 = emptyset; - ASSERT_EQ(output1, BagsUtils::evaluate(input1)); + ASSERT_EQ(output1, rewrite(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -551,13 +553,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set) Node input2 = d_nodeManager->mkNode(BAG_TO_SET, x_4); Node output2 = xSingleton; - ASSERT_EQ(output2, BagsUtils::evaluate(input2)); + ASSERT_EQ(output2, rewrite(input2)); // for normal sets, the first node is the largest, not smallest Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input3 = d_nodeManager->mkNode(BAG_TO_SET, normalBag); Node output3 = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton); - ASSERT_EQ(output3, BagsUtils::evaluate(input3)); + ASSERT_EQ(output3, rewrite(input3)); } } // namespace test } // namespace cvc5::internal