KIND_ENUM(BAG_PARTITION, internal::Kind::BAG_PARTITION),
KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT),
KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT),
+ KIND_ENUM(TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE),
+ KIND_ENUM(TABLE_JOIN, internal::Kind::TABLE_JOIN),
/* Strings ---------------------------------------------------------- */
KIND_ENUM(STRING_CONCAT, internal::Kind::STRING_CONCAT),
KIND_ENUM(STRING_IN_REGEXP, internal::Kind::STRING_IN_REGEXP),
{internal::Kind::TABLE_PRODUCT, TABLE_PRODUCT},
{internal::Kind::TABLE_PROJECT, TABLE_PROJECT},
{internal::Kind::TABLE_PROJECT_OP, TABLE_PROJECT},
+ {internal::Kind::TABLE_AGGREGATE_OP, TABLE_AGGREGATE},
+ {internal::Kind::TABLE_AGGREGATE, TABLE_AGGREGATE},
+ {internal::Kind::TABLE_JOIN_OP, TABLE_JOIN},
+ {internal::Kind::TABLE_JOIN, TABLE_JOIN},
/* Strings --------------------------------------------------------- */
{internal::Kind::STRING_CONCAT, STRING_CONCAT},
{internal::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
case TABLE_PROJECT:
res = mkOpHelper(kind, internal::TableProjectOp(args));
break;
+ case TABLE_AGGREGATE:
+ res = mkOpHelper(kind, internal::TableAggregateOp(args));
+ break;
+ case TABLE_JOIN:
+ res = mkOpHelper(kind, internal::TableJoinOp(args));
+ break;
default:
if (nargs == 0)
{
* Table projection operator extends tuple projection operator to tables.
*
* - Arity: ``1``
- * - ``1:`` Term of tuple Sort
+ * - ``1:`` Term of table Sort
*
* - Indices: ``n``
- * - ``1..n:`` The table indices to project
+ * - ``1..n:`` Indices of the projection
*
* - Create Term of this Kind with:
* - Solver::mkTerm(const Op&, const std::vector<Term>&) const
* - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
*/
TABLE_PROJECT,
+ /**
+ * Table aggregate operator has the form
+ * :math:`((\_ \; table.aggr \; n_1 ... n_k) f i A)`
+ * where :math:`n_1, ..., n_k` are natural numbers,
+ * :math:`f` is a function of type
+ * :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)`,
+ * :math:`i` has the type :math:`T`,
+ * and :math`A` has type :math:`Table T_1 ... T_j`.
+ * The returned type is :math:`(Bag \; T)`.
+ *
+ * This operator aggregates elements in A that have the same tuple projection
+ * with indices n_1, ..., n_k using the combining function :math:`f`,
+ * and initial value :math:`i`.
+ *
+ * - Arity: ``3``
+ * - ``1:`` Term of sort :math:`(\rightarrow (Tuple \; T_1 \; ... \; T_j)\; T \; T)`
+ * - ``2:`` Term of Sort :math:`T`
+ * - ``3:`` Term of table sort :math:`Table T_1 ... T_j`
+ *
+ * - Indices: ``n``
+ * - ``1..n:`` Indices of the projection
+ *
+ * - Create Term of this Kind with:
+ * - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+ *
+ * - Create Op of this kind with:
+ * - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+ */
+ TABLE_AGGREGATE,
+ /**
+ * Table join operator has the form
+ * :math:`((\_ \; table.join \; m_1 \; n_1 \; \dots \; m_k \; n_k) \; A \; B)`
+ * where * :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers,
+ * and A, B are tables.
+ * This operator filters the product of two bags based on the equality of
+ * projected tuples using indices :math:`m_1, \dots, m_k` in table A,
+ * and indices :math:`n_1, \dots, n_k` in table B.
+ *
+ * - Arity: ``2``
+ * - ``1:`` Term of table Sort
+ * - ``2:`` Term of table Sort
+ *
+ * - Indices: ``n``
+ * - ``1..n:`` Indices of the projection
+ *
+ * - Create Term of this Kind with:
+ * - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+ *
+ * - Create Op of this kind with:
+ * - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+ */
+ TABLE_JOIN,
/* Strings --------------------------------------------------------------- */
/**
cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_PROJECT, indices);
expr = SOLVER->mkTerm(op, {expr});
}
+ | LPAREN_TOK TABLE_AGGREGATE_TOK term[expr,expr2] RPAREN_TOK
+ {
+ std::vector<uint32_t> indices;
+ cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_AGGREGATE, indices);
+ expr = SOLVER->mkTerm(op, {expr});
+ }
+ | LPAREN_TOK TABLE_JOIN_TOK term[expr,expr2] RPAREN_TOK
+ {
+ std::vector<uint32_t> indices;
+ cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_JOIN, indices);
+ expr = SOLVER->mkTerm(op, {expr});
+ }
| /* an atomic term (a term with no subterms) */
termAtomic[atomTerm] { expr = atomTerm; }
;
p.d_kind = cvc5::TABLE_PROJECT;
p.d_op = SOLVER->mkOp(cvc5::TABLE_PROJECT, numerals);
}
+ | TABLE_AGGREGATE_TOK nonemptyNumeralList[numerals]
+ {
+ // we adopt a special syntax (_ table.aggr i_1 ... i_n) where
+ // i_1, ..., i_n are numerals
+ p.d_kind = cvc5::TABLE_AGGREGATE;
+ p.d_op = SOLVER->mkOp(cvc5::TABLE_AGGREGATE, numerals);
+ }
+ | TABLE_JOIN_TOK nonemptyNumeralList[numerals]
+ {
+ // we adopt a special syntax (_ table.join i_1 j_1 ... i_n j_n) where
+ // i_1, ..., i_n, j_1, ..., j_n are numerals
+ p.d_kind = cvc5::TABLE_JOIN;
+ p.d_op = SOLVER->mkOp(cvc5::TABLE_JOIN, numerals);
+ }
| functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals]
{
cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName);
TUPLE_CONST_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple';
TUPLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple.project';
TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.project';
+TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.aggr';
+TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join';
FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card';
HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->';
Trace("parser") << "applyParseOp: return selector " << ret << std::endl;
return ret;
}
- else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT)
+ else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT
+ || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN)
{
cvc5::Term ret = d_solver->mkTerm(p.d_op, args);
Trace("parser") << "applyParseOp: return projection " << ret << std::endl;
}
return;
}
+ case kind::TABLE_AGGREGATE:
+ {
+ TableAggregateOp op = n.getOperator().getConst<TableAggregateOp>();
+ if (op.getIndices().empty())
+ {
+ // e.g. (table.project function initial_value bag)
+ out << "table.aggr " << n[0] << " " << n[1] << " " << n[2] << ")";
+ }
+ else
+ {
+ // e.g. ((_ table.aggr 0) function initial_value bag)
+ out << "(_ table.aggr" << op << ") " << n[0] << " " << n[1] << " " << n[2]
+ << ")";
+ }
+ return;
+ }
+ case kind::TABLE_JOIN:
+ {
+ TableJoinOp op = n.getOperator().getConst<TableJoinOp>();
+ if (op.getIndices().empty())
+ {
+ // e.g. (table.join A B)
+ out << "table.join " << n[0] << " " << n[1] << ")";
+ }
+ else
+ {
+ // e.g. ((_ table.project 0 1 2 3) A B)
+ out << "(_ table.join" << op << ") " << n[0] << " " << n[1] << ")";
+ }
+ return;
+ }
case kind::CONSTRUCTOR_TYPE:
{
out << n[n.getNumChildren()-1];
}
}
stringstream parens;
-
+
for(size_t i = 0, c = 1; i < n.getNumChildren(); ) {
if(toDepth != 0) {
toStream(out, n[i], toDepth < 0 ? toDepth : toDepth - c, lbind);
case kind::GEQ: return ">=";
case kind::DIVISION:
case kind::DIVISION_TOTAL: return "/";
- case kind::INTS_DIVISION_TOTAL:
+ case kind::INTS_DIVISION_TOTAL:
case kind::INTS_DIVISION: return "div";
- case kind::INTS_MODULUS_TOTAL:
+ case kind::INTS_MODULUS_TOTAL:
case kind::INTS_MODULUS: return "mod";
case kind::ABS: return "abs";
case kind::IS_INTEGER: return "is_int";
case kind::BAG_PARTITION: return "bag.partition";
case kind::TABLE_PRODUCT: return "table.product";
case kind::TABLE_PROJECT: return "table.project";
+ case kind::TABLE_AGGREGATE: return "table.aggr";
+ case kind::TABLE_JOIN: return "table.join";
// fp theory
case kind::FLOATINGPOINT_FP: return "fp";
#include "expr/bound_var_manager.h"
#include "expr/emptybag.h"
#include "expr/skolem_manager.h"
+#include "table_project_op.h"
+#include "theory/datatypes/tuple_utils.h"
#include "theory/quantifiers/fmf/bounded_integers.h"
#include "util/rational.h"
namespace theory {
namespace bags {
-BagReduction::BagReduction(Env& env) : EnvObj(env) {}
+BagReduction::BagReduction() {}
BagReduction::~BagReduction() {}
Node BagReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
{
Assert(node.getKind() == BAG_FOLD);
- if (d_env.getLogicInfo().isHigherOrder())
- {
- NodeManager* nm = NodeManager::currentNM();
- SkolemManager* sm = nm->getSkolemManager();
- Node f = node[0];
- Node t = node[1];
- Node A = node[2];
- Node zero = nm->mkConstInt(Rational(0));
- Node one = nm->mkConstInt(Rational(1));
- // types
- TypeNode bagType = A.getType();
- TypeNode elementType = A.getType().getBagElementType();
- TypeNode integerType = nm->integerType();
- TypeNode ufType = nm->mkFunctionType(integerType, elementType);
- TypeNode resultType = t.getType();
- TypeNode combineType = nm->mkFunctionType(integerType, resultType);
- TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType);
- // skolem functions
- Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A);
- Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A);
- Node unionDisjoint = sm->mkSkolemFunction(
- SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A);
- Node combine = sm->mkSkolemFunction(
- SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A});
-
- BoundVarManager* bvm = nm->getBoundVarManager();
- Node i =
- bvm->mkBoundVar<FirstIndexVarAttribute>(node, "i", nm->integerType());
- Node iList = nm->mkNode(BOUND_VAR_LIST, i);
- Node iMinusOne = nm->mkNode(SUB, i, one);
- Node uf_i = nm->mkNode(APPLY_UF, uf, i);
- Node combine_0 = nm->mkNode(APPLY_UF, combine, zero);
- Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne);
- Node combine_i = nm->mkNode(APPLY_UF, combine, i);
- Node combine_n = nm->mkNode(APPLY_UF, combine, n);
- Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero);
- Node unionDisjoint_iMinusOne =
- nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne);
- Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i);
- Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n);
- Node combine_0_equal = combine_0.eqNode(t);
- Node combine_i_equal =
- combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne));
- Node unionDisjoint_0_equal =
- unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
- Node singleton = nm->mkBag(elementType, uf_i, one);
-
- Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
- nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
- Node interval_i =
- nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n));
-
- Node body_i =
- nm->mkNode(IMPLIES,
- interval_i,
- nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal));
- Node forAll_i =
- quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
- Node nonNegative = nm->mkNode(GEQ, n, zero);
- Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
- asserts.push_back(forAll_i);
- asserts.push_back(combine_0_equal);
- asserts.push_back(unionDisjoint_0_equal);
- asserts.push_back(unionDisjoint_n_equal);
- asserts.push_back(nonNegative);
- return combine_n;
- }
- return Node::null();
+ NodeManager* nm = NodeManager::currentNM();
+ SkolemManager* sm = nm->getSkolemManager();
+ Node f = node[0];
+ Node t = node[1];
+ Node A = node[2];
+ Node zero = nm->mkConstInt(Rational(0));
+ Node one = nm->mkConstInt(Rational(1));
+ // types
+ TypeNode bagType = A.getType();
+ TypeNode elementType = A.getType().getBagElementType();
+ TypeNode integerType = nm->integerType();
+ TypeNode ufType = nm->mkFunctionType(integerType, elementType);
+ TypeNode resultType = t.getType();
+ TypeNode combineType = nm->mkFunctionType(integerType, resultType);
+ TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType);
+ // skolem functions
+ Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A);
+ Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A);
+ Node unionDisjoint = sm->mkSkolemFunction(
+ SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A);
+ Node combine = sm->mkSkolemFunction(
+ SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A});
+
+ BoundVarManager* bvm = nm->getBoundVarManager();
+ Node i =
+ bvm->mkBoundVar<FirstIndexVarAttribute>(node, "i", nm->integerType());
+ Node iList = nm->mkNode(BOUND_VAR_LIST, i);
+ Node iMinusOne = nm->mkNode(SUB, i, one);
+ Node uf_i = nm->mkNode(APPLY_UF, uf, i);
+ Node combine_0 = nm->mkNode(APPLY_UF, combine, zero);
+ Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne);
+ Node combine_i = nm->mkNode(APPLY_UF, combine, i);
+ Node combine_n = nm->mkNode(APPLY_UF, combine, n);
+ Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero);
+ Node unionDisjoint_iMinusOne = nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne);
+ Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i);
+ Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n);
+ Node combine_0_equal = combine_0.eqNode(t);
+ Node combine_i_equal =
+ combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne));
+ Node unionDisjoint_0_equal =
+ unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
+ Node singleton = nm->mkBag(elementType, uf_i, one);
+
+ Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
+ nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
+ Node interval_i =
+ nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n));
+
+ Node body_i =
+ nm->mkNode(IMPLIES,
+ interval_i,
+ nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal));
+ Node forAll_i = quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
+ Node nonNegative = nm->mkNode(GEQ, n, zero);
+ Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
+ asserts.push_back(forAll_i);
+ asserts.push_back(combine_0_equal);
+ asserts.push_back(unionDisjoint_0_equal);
+ asserts.push_back(unionDisjoint_n_equal);
+ asserts.push_back(nonNegative);
+ return combine_n;
}
Node BagReduction::reduceCardOperator(Node node, std::vector<Node>& asserts)
return cardinality_n;
}
+Node BagReduction::reduceAggregateOperator(Node node)
+{
+ Assert(node.getKind() == TABLE_AGGREGATE);
+ NodeManager* nm = NodeManager::currentNM();
+ BoundVarManager* bvm = nm->getBoundVarManager();
+ Node function = node[0];
+ TypeNode elementType = function.getType().getArgTypes()[0];
+ Node initialValue = node[1];
+ Node A = node[2];
+ const std::vector<uint32_t>& indices =
+ node.getOperator().getConst<TableAggregateOp>().getIndices();
+
+ Node t1 = bvm->mkBoundVar<FirstIndexVarAttribute>(node, "t1", elementType);
+ Node t2 = bvm->mkBoundVar<SecondIndexVarAttribute>(node, "t2", elementType);
+ Node list = nm->mkNode(BOUND_VAR_LIST, t1, t2);
+ Node body = nm->mkConst(true);
+ for (uint32_t i : indices)
+ {
+ Node select1 = datatypes::TupleUtils::nthElementOfTuple(t1, i);
+ Node select2 = datatypes::TupleUtils::nthElementOfTuple(t2, i);
+ Node equal = select1.eqNode(select2);
+ body = body.andNode(equal);
+ }
+
+ Node lambda = nm->mkNode(LAMBDA, list, body);
+ Node partition = nm->mkNode(BAG_PARTITION, lambda, A);
+
+ Node bag = bvm->mkBoundVar<FirstIndexVarAttribute>(
+ partition, "bag", nm->mkBagType(elementType));
+ Node foldList = nm->mkNode(BOUND_VAR_LIST, bag);
+ Node foldBody = nm->mkNode(BAG_FOLD, function, initialValue, bag);
+
+ Node fold = nm->mkNode(LAMBDA, foldList, foldBody);
+ Node map = nm->mkNode(BAG_MAP, fold, partition);
+ return map;
+}
+
} // namespace bags
} // namespace theory
} // namespace cvc5::internal
/**
* class for bag reductions
*/
-class BagReduction : EnvObj
+class BagReduction
{
public:
- BagReduction(Env& env);
+ BagReduction();
~BagReduction();
/**
* combine: Int -> T2 is an uninterpreted function
* unionDisjoint: Int -> (Bag T1) is an uninterpreted function
*/
- Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
+ static Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
/**
* @param node a term of the form (bag.card A) where A: (Bag T) is a bag
* cardinality: Int -> Int is an uninterpreted function
* unionDisjoint: Int -> (Bag T1) is an uninterpreted function
*/
- Node reduceCardOperator(Node node, std::vector<Node>& asserts);
-
- private:
+ static Node reduceCardOperator(Node node, std::vector<Node>& asserts);
+ /**
+ * @param node of the form ((_ table.aggr n1 ... nk) f initial A))
+ * @return reduction term that uses map, fold, and partition using
+ * tuple projection as the equivalence relation as follows:
+ * (bag.map
+ * (lambda ((B Table)) (bag.fold f initial B))
+ * (bag.partition
+ * (lambda ((t1 Tuple) (t2 Tuple)) ; equivalence relation
+ * (=
+ * ((_ tuple.project n1 ... nk) t1)
+ * ((_ tuple.project n1 ... nk) t2)))
+ * A))
+ */
+ static Node reduceAggregateOperator(Node node);
};
} // namespace bags
case kind::BAG_FILTER: checkFilter(n); break;
case kind::BAG_MAP: checkMap(n); break;
case kind::TABLE_PRODUCT: checkProduct(n); break;
+ case kind::TABLE_JOIN: checkJoin(n); break;
default: break;
}
it++;
}
}
+void BagSolver::checkJoin(Node n)
+{
+ Assert(n.getKind() == TABLE_JOIN);
+ const set<Node>& elementsA = d_state.getElements(n[0]);
+ const set<Node>& elementsB = d_state.getElements(n[1]);
+
+ for (const Node& e1 : elementsA)
+ {
+ for (const Node& e2 : elementsB)
+ {
+ InferInfo i = d_ig.joinUp(
+ n, d_state.getRepresentative(e1), d_state.getRepresentative(e2));
+ d_im.lemmaTheoryInference(&i);
+ }
+ }
+
+ std::set<Node> elements = d_state.getElements(n);
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.joinDown(n, d_state.getRepresentative(e));
+ d_im.lemmaTheoryInference(&i);
+ }
+}
+
} // namespace bags
} // namespace theory
} // namespace cvc5::internal
void checkFilter(Node n);
/** apply inference rules for product operator */
void checkProduct(Node n);
+ /** apply inference rules for join operator */
+ void checkJoin(Node n);
/** The solver state object */
SolverState& d_state;
}
else if (BagsUtils::areChildrenConstants(n))
{
- Node value = BagsUtils::evaluate(n);
+ Node value = BagsUtils::evaluate(d_rewriter, n);
response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
}
else
case BAG_FOLD: response = postRewriteFold(n); break;
case BAG_PARTITION: response = postRewritePartition(n); break;
case TABLE_PRODUCT: response = postRewriteProduct(n); break;
+ case TABLE_AGGREGATE: response = postRewriteAggregate(n); break;
default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
}
}
return BagsRewriteResponse(n, Rewrite::NONE);
}
+BagsRewriteResponse BagsRewriter::postRewriteAggregate(const TNode& n) const
+{
+ Assert(n.getKind() == kind::TABLE_AGGREGATE);
+ if (n[1].isConst() && n[2].isConst())
+ {
+ Node ret = BagsUtils::evaluateTableAggregate(d_rewriter, n);
+ if (ret != n)
+ {
+ return BagsRewriteResponse(ret, Rewrite::AGGREGATE_CONST);
+ }
+ }
+
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
{
Assert(n.getKind() == TABLE_PRODUCT);
*/
BagsRewriteResponse postRewriteFold(const TNode& n) const;
BagsRewriteResponse postRewritePartition(const TNode& n) const;
+ BagsRewriteResponse postRewriteAggregate(const TNode& n) const;
/**
* rewrites for n include:
* - (bag.product A (as bag.empty T2)) = (as bag.empty T)
#include "expr/emptybag.h"
#include "smt/logic_exception.h"
#include "table_project_op.h"
+#include "theory/bags/bag_reduction.h"
#include "theory/datatypes/tuple_utils.h"
#include "theory/rewriter.h"
#include "theory/sets/normal_form.h"
return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
}
-Node BagsUtils::evaluate(TNode n)
+Node BagsUtils::evaluate(Rewriter* rewriter, TNode n)
{
Assert(areChildrenConstants(n));
if (n.isConst())
case BAG_FILTER: return evaluateBagFilter(n);
case BAG_FOLD: return evaluateBagFold(n);
case TABLE_PRODUCT: return evaluateProduct(n);
+ case TABLE_JOIN: return evaluateJoin(rewriter, n);
case TABLE_PROJECT: return evaluateTableProject(n);
default: break;
}
return ret;
}
+Node BagsUtils::evaluateTableAggregate(Rewriter* rewriter, TNode n)
+{
+ Assert(n.getKind() == TABLE_AGGREGATE);
+ if (!(n[1].isConst() && n[2].isConst()))
+ {
+ // we can't proceed further.
+ return n;
+ }
+
+ Node reduction = BagReduction::reduceAggregateOperator(n);
+ return reduction;
+}
+
Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2)
{
- Assert(n.getKind() == TABLE_PRODUCT);
+ Assert(n.getKind() == TABLE_PRODUCT || n.getKind() == TABLE_JOIN);
Node A = n[0];
Node B = n[1];
TypeNode typeA = A.getType().getBagElementType();
return ret;
}
+Node BagsUtils::evaluateJoin(Rewriter* rewriter, TNode n)
+{
+ Assert(n.getKind() == TABLE_JOIN);
+
+ Node A = n[0];
+ Node B = n[1];
+ auto [aIndices, bIndices] = splitTableJoinIndices(n);
+
+ std::map<Node, Rational> elementsA = BagsUtils::getBagElements(A);
+ std::map<Node, Rational> elementsB = BagsUtils::getBagElements(B);
+
+ std::map<Node, Rational> elements;
+
+ for (const auto& [a, countA] : elementsA)
+ {
+ Node aProjection = TupleUtils::getTupleProjection(aIndices, a);
+ aProjection = rewriter->rewrite(aProjection);
+ Assert(aProjection.isConst());
+ for (const auto& [b, countB] : elementsB)
+ {
+ Node bProjection = TupleUtils::getTupleProjection(bIndices, b);
+ bProjection = rewriter->rewrite(bProjection);
+ Assert(bProjection.isConst());
+ if (aProjection == bProjection)
+ {
+ Node element = constructProductTuple(n, a, b);
+ elements[element] = countA * countB;
+ }
+ }
+ }
+
+ Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements);
+ return ret;
+}
+
Node BagsUtils::evaluateTableProject(TNode n)
{
Assert(n.getKind() == TABLE_PROJECT);
return ret;
}
+std::pair<std::vector<uint32_t>, std::vector<uint32_t>>
+BagsUtils::splitTableJoinIndices(Node n)
+{
+ Assert(n.getKind() == kind::TABLE_JOIN && n.hasOperator()
+ && n.getOperator().getKind() == kind::TABLE_JOIN_OP);
+ TableJoinOp op = n.getOperator().getConst<TableJoinOp>();
+ const std::vector<uint32_t>& indices = op.getIndices();
+ size_t joinSize = indices.size() / 2;
+ std::vector<uint32_t> indices1(joinSize), indices2(joinSize);
+
+ for (size_t i = 0, index = 0; i < joinSize; i += 2, ++index)
+ {
+ indices1[index] = indices[i];
+ indices2[index] = indices[i + 1];
+ }
+ return std::make_pair(indices1, indices2);
+}
+
} // namespace bags
} // namespace theory
} // namespace cvc5::internal
* evaluate the node n to a constant value.
* As a precondition, children of n should be constants.
*/
- static Node evaluate(TNode n);
+ static Node evaluate(Rewriter* rewriter, TNode n);
/**
* get the elements along with their multiplicities in a given bag
* @param n has the form (bag.partition r A) where A is a constant bag
* @return a partition of A based on the equivalence relation r
*/
- static Node evaluateBagPartition(Rewriter *rewriter, TNode n);
+ static Node evaluateBagPartition(Rewriter* rewriter, TNode n);
+
+ /**
+ * @param n has the form ((_ table.aggr n1 ... n_k) f initial A)
+ * where initial and A are constants
+ * @return the aggregation result.
+ */
+ static Node evaluateTableAggregate(Rewriter* rewriter, TNode n);
/**
* @param n has the form (bag.filter p A) where A is a constant bag
*/
static Node evaluateProduct(TNode n);
+ /**
+ * @param n of the form ((_ table.join (m_1 n_1 ... m_k n_k) ) A B) where
+ * A, B are constants
+ * @return the evaluation of inner joining tables A B on columns (m_1, n_1,
+ * ..., m_k, n_k)
+ */
+ static Node evaluateJoin(Rewriter* rewriter, TNode n);
+
/**
* @param n of the form ((_ table.project i_1 ... i_n) A) where A is a
* constant
*/
static Node evaluateTableProject(TNode n);
+ /**
+ * @param n has the form ((_ table.join m1 n1 ... mk nk) A B)) where A, B are
+ * tables and m1 n1 ... mk nk are indices
+ * @return the pair <[m1 ... mk], [n1 ... nk]>
+ */
+ static std::pair<std::vector<uint32_t>, std::vector<uint32_t>>
+ splitTableJoinIndices(Node n);
+
private:
/**
* a high order helper function that return a constant bag that is the result
namespace bags {
CardSolver::CardSolver(Env& env, SolverState& s, InferenceManager& im)
- : EnvObj(env), d_state(s), d_ig(&s, &im), d_im(im), d_bagReduction(env)
+ : EnvObj(env), d_state(s), d_ig(&s, &im), d_im(im)
{
d_nm = NodeManager::currentNM();
d_zero = d_nm->mkConstInt(Rational(0));
Node card = d_nm->mkNode(BAG_CARD, parent);
std::vector<Node> asserts;
- Node reduced = d_bagReduction.reduceCardOperator(card, asserts);
+ Node reduced = BagReduction::reduceCardOperator(card, asserts);
asserts.push_back(card.eqNode(reduced));
InferInfo inferInfo(&d_im, InferenceId::BAGS_CARD);
inferInfo.d_premises.push_back(premise);
InferenceManager& d_im;
NodeManager* d_nm;
- /** bag reduction */
- BagReduction d_bagReduction;
-
/**
* A map from bag representatives to sets of bag representatives with the
* invariant that each key is the disjoint union of each set in the value.
#include "theory/bags/bags_utils.h"
#include "theory/bags/inference_manager.h"
#include "theory/bags/solver_state.h"
+#include "theory/bags/table_project_op.h"
#include "theory/datatypes/tuple_utils.h"
#include "theory/quantifiers/fmf/bounded_integers.h"
#include "theory/uf/equality_engine.h"
Node countA = getMultiplicityTerm(e1, A);
Node countB = getMultiplicityTerm(e2, B);
+ inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countA, d_one));
+ inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countB, d_one));
+
Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
Node count = getMultiplicityTerm(tuple, skolem);
Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
Node count = getMultiplicityTerm(e, skolem);
+ inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count, d_one));
+
+ Node multiply = d_nm->mkNode(MULT, countA, countB);
+ inferInfo.d_conclusion = count.eqNode(multiply);
+
+ return inferInfo;
+}
+
+InferInfo InferenceGenerator::joinUp(Node n, Node e1, Node e2)
+{
+ Assert(n.getKind() == TABLE_JOIN);
+ Node A = n[0];
+ Node B = n[1];
+ Node tuple = BagsUtils::constructProductTuple(n, e1, e2);
+
+ std::vector<Node> aElements = TupleUtils::getTupleElements(e1);
+ std::vector<Node> bElements = TupleUtils::getTupleElements(e2);
+ const std::vector<uint32_t>& indices =
+ n.getOperator().getConst<TableJoinOp>().getIndices();
+
+ InferInfo inferInfo(d_im, InferenceId::TABLES_PRODUCT_UP);
+
+ for (size_t i = 0; i < indices.size(); i += 2)
+ {
+ Node x = aElements[indices[i]];
+ Node y = bElements[indices[i + 1]];
+ Node equal = x.eqNode(y);
+ inferInfo.d_premises.push_back(equal);
+ }
+
+ Node countA = getMultiplicityTerm(e1, A);
+ Node countB = getMultiplicityTerm(e2, B);
+
+ inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countA, d_one));
+ inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countB, d_one));
+
+ Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+ Node count = getMultiplicityTerm(tuple, skolem);
Node multiply = d_nm->mkNode(MULT, countA, countB);
inferInfo.d_conclusion = count.eqNode(multiply);
+ return inferInfo;
+}
+
+InferInfo InferenceGenerator::joinDown(Node n, Node e)
+{
+ Assert(n.getKind() == TABLE_JOIN);
+ Assert(e.getType().isSubtypeOf(n.getType().getBagElementType()));
+
+ Node A = n[0];
+ Node B = n[1];
+
+ TypeNode tupleBType = B.getType().getBagElementType();
+ TypeNode tupleAType = A.getType().getBagElementType();
+ size_t tupleALength = tupleAType.getTupleLength();
+ size_t productTupleLength = n.getType().getBagElementType().getTupleLength();
+
+ std::vector<Node> elements = TupleUtils::getTupleElements(e);
+ Node a = TupleUtils::constructTupleFromElements(
+ tupleAType, elements, 0, tupleALength - 1);
+ Node b = TupleUtils::constructTupleFromElements(
+ tupleBType, elements, tupleALength, productTupleLength - 1);
+
+ InferInfo inferInfo(d_im, InferenceId::TABLES_JOIN_DOWN);
+
+ Node countA = getMultiplicityTerm(a, A);
+ Node countB = getMultiplicityTerm(b, B);
+
+ Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+ Node count = getMultiplicityTerm(e, skolem);
+ inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count, d_one));
+
+ Node multiply = d_nm->mkNode(MULT, countA, countB);
+ Node multiplicityConstraint = count.eqNode(multiply);
+ const std::vector<uint32_t>& indices =
+ n.getOperator().getConst<TableJoinOp>().getIndices();
+ Node joinConstraints = d_true;
+ for (size_t i = 0; i < indices.size(); i += 2)
+ {
+ Node x = elements[indices[i]];
+ Node y = elements[tupleALength + indices[i + 1]];
+ Node equal = x.eqNode(y);
+ joinConstraints = joinConstraints.andNode(equal);
+ }
+ inferInfo.d_conclusion = joinConstraints.andNode(multiplicityConstraint);
return inferInfo;
}
InferInfo filterUpwards(Node n, Node e);
/**
- * @param n is a (table.product A B) where A, B are bags of tuples
+ * @param n is a (table.product A B) where A, B are tables
* @param e1 an element of the form (tuple a1 ... am)
* @param e2 an element of the form (tuple b1 ... bn)
* @return an inference that represents the following
- * (=
- * (bag.count (tuple a1 ... am b1 ... bn) skolem)
- * (* (bag.count e1 A) (bag.count e2 B)))
+ * (=> (and (bag.member e1 A) (bag.member e2 B))
+ * (=
+ * (bag.count (tuple a1 ... am b1 ... bn) skolem)
+ * (* (bag.count e1 A) (bag.count e2 B))))
* where skolem is a variable equals (bag.product A B)
*/
InferInfo productUp(Node n, Node e1, Node e2);
/**
- * @param n is a (table.product A B) where A, B are bags of tuples
+ * @param n is a (table.product A B) where A, B are tables
* @param e an element of the form (tuple a1 ... am b1 ... bn)
* @return an inference that represents the following
- * (=
- * (bag.count (tuple a1 ... am b1 ... bn) skolem)
- * (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B)))
+ * (=> (bag.member e skolem)
+ * (=
+ * (bag.count (tuple a1 ... am b1 ... bn) skolem)
+ * (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B))))
* where skolem is a variable equals (bag.product A B)
*/
InferInfo productDown(Node n, Node e);
+ /**
+ * @param n is a ((_ table.join m1 n1 ... mk nk) A B) where A, B are tables
+ * @param e1 an element of the form (tuple a1 ... am)
+ * @param e2 an element of the form (tuple b1 ... bn)
+ * @return an inference that represents the following
+ * (=> (and
+ * (bag.member e1 A)
+ * (bag.member e2 B)
+ * (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk}))
+ * (=
+ * (bag.count (tuple a1 ... am b1 ... bn) skolem)
+ * (* (bag.count e1 A) (bag.count e2 B))))
+ * where skolem is a variable equals ((_ table.join m1 n1 ... mk nk) A B)
+ */
+ InferInfo joinUp(Node n, Node e1, Node e2);
+
+ /**
+ * @param n is a (table.product A B) where A, B are tables
+ * @param e an element of the form (tuple a1 ... am b1 ... bn)
+ * @return an inference that represents the following
+ * (=> (bag.member e skolem)
+ * (and
+ * (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk})
+ * (=
+ * (bag.count (tuple a1 ... am b1 ... bn) skolem)
+ * (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B))))
+ * where skolem is a variable equals ((_ table.join m1 n1 ... mk nk) A B)
+ */
+ InferInfo joinDown(Node n, Node e);
+
/**
* @param element of type T
* @param bag of type (bag T)
parameterized TABLE_PROJECT TABLE_PROJECT_OP 1 "table projection"
+# table.aggregate operator
+constant TABLE_AGGREGATE_OP \
+ class \
+ TableAggregateOp \
+ ::cvc5::internal::TableAggregateOpHashFunction \
+ "theory/bags/table_project_op.h" \
+ "operator for TABLE_AGGREGATE; payload is an instance of the cvc5::internal::TableAggregateOp class"
+
+parameterized TABLE_AGGREGATE TABLE_AGGREGATE_OP 3 "table aggregate"
+
+# table.join operator
+constant TABLE_JOIN_OP \
+ class \
+ TableJoinOp \
+ ::cvc5::internal::TableJoinOpHashFunction \
+ "theory/bags/table_project_op.h" \
+ "operator for TABLE_JOIN; payload is an instance of the cvc5::internal::TableJoinOp class"
+
+parameterized TABLE_JOIN TABLE_JOIN_OP 2 "table join"
+
typerule TABLE_PRODUCT ::cvc5::internal::theory::bags::TableProductTypeRule
typerule TABLE_PROJECT_OP "SimpleTypeRule<RBuiltinOperator>"
typerule TABLE_PROJECT ::cvc5::internal::theory::bags::TableProjectTypeRule
+typerule TABLE_AGGREGATE_OP "SimpleTypeRule<RBuiltinOperator>"
+typerule TABLE_AGGREGATE ::cvc5::internal::theory::bags::TableAggregateTypeRule
+typerule TABLE_JOIN_OP "SimpleTypeRule<RBuiltinOperator>"
+typerule TABLE_JOIN ::cvc5::internal::theory::bags::TableJoinTypeRule
endtheory
switch (r)
{
case Rewrite::NONE: return "NONE";
+ case Rewrite::AGGREGATE_CONST: return "AGGREGATE_CONST";
case Rewrite::BAG_MAKE_COUNT_NEGATIVE: return "BAG_MAKE_COUNT_NEGATIVE";
case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT";
case Rewrite::CARD_BAG_MAKE: return "CARD_BAG_MAKE";
enum class Rewrite : uint32_t
{
NONE, // no rewrite happened
+ AGGREGATE_CONST,
BAG_MAKE_COUNT_NEGATIVE,
CARD_DISJOINT,
CARD_BAG_MAKE,
{
}
+TableAggregateOp::TableAggregateOp(std::vector<uint32_t> indices)
+ : ProjectOp(std::move(indices))
+{
+}
+
+TableJoinOp::TableJoinOp(std::vector<uint32_t> indices)
+ : ProjectOp(std::move(indices))
+{
+}
+
} // namespace cvc5::internal
}; /* class TableProjectOp */
/**
- * Hash function for the TupleProjectOpHashFunction objects.
+ * Hash function for the TableProjectOpHashFunction objects.
*/
struct TableProjectOpHashFunction : public ProjectOpHashFunction
{
-}; /* struct TupleProjectOpHashFunction */
+}; /* struct TableProjectOpHashFunction */
+
+class TableAggregateOp : public ProjectOp
+{
+ public:
+ explicit TableAggregateOp(std::vector<uint32_t> indices);
+ TableAggregateOp(const TableAggregateOp& op) = default;
+}; /* class TableAggregateOp */
+
+/**
+ * Hash function for the TableAggregateOpHashFunction objects.
+ */
+struct TableAggregateOpHashFunction : public ProjectOpHashFunction
+{
+}; /* struct TableAggregateOpHashFunction */
+
+
+class TableJoinOp : public ProjectOp
+{
+ public:
+ explicit TableJoinOp(std::vector<uint32_t> indices);
+ TableJoinOp(const TableJoinOp& op) = default;
+}; /* class TableJoinOp */
+
+/**
+ * Hash function for the TableJoinOpHashFunction objects.
+ */
+struct TableJoinOpHashFunction : public ProjectOpHashFunction
+{
+}; /* struct TableJoinOpHashFunction */
+
} // namespace cvc5::internal
d_rewriter(env.getRewriter(), &d_statistics.d_rewrites),
d_termReg(env, d_state, d_im),
d_solver(env, d_state, d_im, d_termReg),
- d_cardSolver(env, d_state, d_im),
- d_bagReduction(env)
+ d_cardSolver(env, d_state, d_im)
{
// use the official theory state and inference manager objects
d_theoryState = &d_state;
d_equalityEngine->addFunctionKind(BAG_PARTITION);
d_equalityEngine->addFunctionKind(TABLE_PRODUCT);
d_equalityEngine->addFunctionKind(TABLE_PROJECT);
+ d_equalityEngine->addFunctionKind(TABLE_AGGREGATE);
+ d_equalityEngine->addFunctionKind(TABLE_JOIN);
}
TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
case kind::BAG_FOLD:
{
std::vector<Node> asserts;
- Node ret = d_bagReduction.reduceFoldOperator(atom, asserts);
+ Node ret = BagReduction::reduceFoldOperator(atom, asserts);
NodeManager* nm = NodeManager::currentNM();
Node andNode = nm->mkNode(AND, asserts);
d_im.lemma(andNode, InferenceId::BAGS_FOLD);
<< andNode << std::endl;
return TrustNode::mkTrustRewrite(atom, ret, nullptr);
}
+ case kind::TABLE_AGGREGATE:
+ {
+ Node ret = BagReduction::reduceAggregateOperator(atom);
+ Trace("bags::ppr") << "reduce(" << atom << ") = " << ret << std::endl;
+ return TrustNode::mkTrustRewrite(atom, ret, nullptr);
+ }
default: return TrustNode::null();
}
}
/** the main solver for bags */
CardSolver d_cardSolver;
- /** bag reduction */
- BagReduction d_bagReduction;
-
/** The representation of the strategy */
Strategy d_strat;
namespace theory {
namespace bags {
+using namespace datatypes;
+
TypeNode BinaryOperatorTypeRule::computeType(NodeManager* nodeManager,
TNode n,
bool check)
throw TypeCheckingExceptionPrivate(n, ss.str());
}
- std::vector<TypeNode> productTupleTypes;
- std::vector<TypeNode> tupleATypes = elementAType.getTupleTypes();
- std::vector<TypeNode> tupleBTypes = elementBType.getTupleTypes();
-
- productTupleTypes.insert(
- productTupleTypes.end(), tupleATypes.begin(), tupleATypes.end());
- productTupleTypes.insert(
- productTupleTypes.end(), tupleBTypes.begin(), tupleBTypes.end());
-
- TypeNode retTupleType = nodeManager->mkTupleType(productTupleTypes);
+ TypeNode retTupleType =
+ TupleUtils::concatTupleTypes(elementAType, elementBType);
TypeNode retType = nodeManager->mkBagType(retTupleType);
return retType;
}
}
TypeNode tupleType = bagType.getBagElementType();
TypeNode retTupleType =
- datatypes::TupleUtils::getTupleProjectionType(indices, tupleType);
+ TupleUtils::getTupleProjectionType(indices, tupleType);
+ return nm->mkBagType(retTupleType);
+}
+
+TypeNode TableAggregateTypeRule::computeType(NodeManager* nm,
+ TNode n,
+ bool check)
+{
+ Assert(n.getKind() == kind::TABLE_AGGREGATE && n.hasOperator()
+ && n.getOperator().getKind() == kind::TABLE_AGGREGATE_OP);
+ TableAggregateOp op = n.getOperator().getConst<TableAggregateOp>();
+ const std::vector<uint32_t>& indices = op.getIndices();
+
+ TypeNode functionType = n[0].getType(check);
+ TypeNode initialValueType = n[1].getType(check);
+ TypeNode bagType = n[2].getType(check);
+
+ if (check)
+ {
+ if (!bagType.isBag())
+ {
+ std::stringstream ss;
+ ss << "TABLE_PROJECT operator expects a table. Found '" << n[2]
+ << "' of type '" << bagType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+
+ TypeNode tupleType = bagType.getBagElementType();
+ if (!tupleType.isTuple())
+ {
+ std::stringstream ss;
+ ss << "TABLE_PROJECT operator expects a table. Found '" << n[2]
+ << "' of type '" << bagType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+
+ TupleUtils::checkTypeIndices(n, tupleType, indices);
+
+ TypeNode elementType = bagType.getBagElementType();
+
+ if (!(functionType.isFunction()))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " T T) as a first argument. "
+ << "Found a term of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ std::vector<TypeNode> argTypes = functionType.getArgTypes();
+ TypeNode rangeType = functionType.getRangeType();
+ if (!(argTypes.size() == 2 && argTypes[0] == elementType
+ && argTypes[1] == rangeType))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " T T). "
+ << "Found a function of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ if (rangeType != initialValueType)
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects an initial value of type "
+ << rangeType << ". Found a term of type '" << initialValueType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ }
+ return nm->mkBagType(functionType.getRangeType());
+}
+
+TypeNode TableJoinTypeRule::computeType(NodeManager* nm, TNode n, bool check)
+{
+ Assert(n.getKind() == kind::TABLE_JOIN && n.hasOperator()
+ && n.getOperator().getKind() == kind::TABLE_JOIN_OP);
+ TableJoinOp op = n.getOperator().getConst<TableJoinOp>();
+ const std::vector<uint32_t>& indices = op.getIndices();
+ Node A = n[0];
+ Node B = n[1];
+ TypeNode aType = A.getType();
+ TypeNode bType = B.getType();
+
+ if (check)
+ {
+ if (!(aType.isBag() && bType.isBag()))
+ {
+ std::stringstream ss;
+ ss << "TABLE_JOIN operator expects two tables. Found '" << n[0] << "', '"
+ << n[1] << "' of types '" << aType << "', '" << bType
+ << "' respectively. ";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+
+ TypeNode aTupleType = aType.getBagElementType();
+ TypeNode bTupleType = bType.getBagElementType();
+ if (!(aTupleType.isTuple() && bTupleType.isTuple()))
+ {
+ std::stringstream ss;
+ ss << "TABLE_JOIN operator expects two tables. Found '" << n[0] << "', '"
+ << n[1] << "' of types '" << aType << "', '" << bType
+ << "' respectively. ";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+
+ if (indices.size() % 2 != 0)
+ {
+ std::stringstream ss;
+ ss << "TABLE_JOIN operator expects even number of indices. Found "
+ << indices.size() << " in term " << n;
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ auto [aIndices, bIndices] = BagsUtils::splitTableJoinIndices(n);
+ TupleUtils::checkTypeIndices(n, aTupleType, aIndices);
+ TupleUtils::checkTypeIndices(n, bTupleType, bIndices);
+
+ // check the types of columns
+ std::vector<TypeNode> aTypes = aTupleType.getTupleTypes();
+ std::vector<TypeNode> bTypes = bTupleType.getTupleTypes();
+ for (uint32_t i = 0; i < aIndices.size(); i++)
+ {
+ if (aTypes[aIndices[i]] != bTypes[bIndices[i]])
+ {
+ std::stringstream ss;
+ ss << "TABLE_JOIN operator expects column " << aIndices[i]
+ << " in table " << n[0] << " to match column " << bIndices[i]
+ << " in table " << n[1] << ". But their types are "
+ << aTypes[aIndices[i]] << " and " << bTypes[bIndices[i]]
+ << "' respectively. ";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ }
+ }
+ TypeNode aTupleType = aType.getBagElementType();
+ TypeNode bTupleType = bType.getBagElementType();
+ TypeNode retTupleType = TupleUtils::concatTupleTypes(aTupleType, bTupleType);
return nm->mkBagType(retTupleType);
}
}; /* struct BagFoldTypeRule */
/**
- * Type rule for (bag.partition r A) to make sure r is a binary operation of type
+ * Type rule for (bag.partition r A) to make sure r is a binary operation of
+ * type
* (-> T1 T1 Bool), and A is a bag of type (Bag T1)
*/
struct BagPartitionTypeRule
static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
}; /* struct BagFoldTypeRule */
+/**
+ * Table aggregate operator is indexed by a list of indices (n_1, ..., n_k).
+ * It ensures that it has 3 arguments:
+ * - A combining function of type (-> (Tuple T_1 ... T_j) T T)
+ * - Initial value of type T
+ * - A table of type (Table T_1 ... T_j) where 0 <= n_1, ..., n_k < j
+ * the returned type is (Bag T).
+ */
+struct TableAggregateTypeRule
+{
+ static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct TableAggregateTypeRule */
+
+/**
+ * Table join operator is indexed by a list of indices (m_1, m_k, n_1, ...,
+ * n_k). It ensures that it has 2 arguments:
+ * - A table of type (Table X_1 ... X_i)
+ * - A table of type (Table Y_1 ... Y_j)
+ * such that indices has constraints 0 <= m_1, ..., mk, n_1, ..., n_k <=
+ * min(i,j) and types has constraints X_{m_1} = Y_{n_1}, ..., X_{m_k} = Y_{n_k}.
+ * The returned type is (Table X_1 ... X_i Y_1 ... Y_j)
+ */
+struct TableJoinTypeRule
+{
+ static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct TableJoinTypeRule */
+
struct BagsProperties
{
static Cardinality computeCardinality(TypeNode type);
#include "tuple_utils.h"
+#include <sstream>
+
#include "expr/dtype.h"
#include "expr/dtype_cons.h"
namespace theory {
namespace datatypes {
+void TupleUtils::checkTypeIndices(Node n,
+ TypeNode tupleType,
+ const std::vector<uint32_t> indices)
+{
+ // make sure all indices are less than the size of the tuple
+ DType dType = tupleType.getDType();
+ DTypeConstructor constructor = dType[0];
+ size_t numArgs = constructor.getNumArgs();
+ for (uint32_t index : indices)
+ {
+ std::stringstream ss;
+ if (index >= numArgs)
+ {
+ ss << "Index " << index << " in term " << n << " is > " << (numArgs - 1)
+ << " the maximum value ";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ }
+}
+
+TypeNode TupleUtils::concatTupleTypes(TypeNode tupleType1, TypeNode tupleType2)
+{
+ std::vector<TypeNode> concatTupleTypes;
+ std::vector<TypeNode> tuple1Types = tupleType1.getTupleTypes();
+ std::vector<TypeNode> tuple2Types = tupleType2.getTupleTypes();
+
+ concatTupleTypes.insert(
+ concatTupleTypes.end(), tuple1Types.begin(), tuple1Types.end());
+ concatTupleTypes.insert(
+ concatTupleTypes.end(), tuple2Types.begin(), tuple2Types.end());
+ NodeManager* nm = NodeManager::currentNM();
+ TypeNode ret = nm->mkTupleType(concatTupleTypes);
+ return ret;
+}
+
Node TupleUtils::nthElementOfTuple(Node tuple, int n_th)
{
if (tuple.getKind() == APPLY_CONSTRUCTOR)
class TupleUtils
{
public:
+ /**
+ *
+ * @param n a node to print in the message if TypeCheckingExceptionPrivate
+ * exception is thrown
+ * @param tupleType the type of the tuple
+ * @param indices a list of indices for projection
+ * @throw an exception if one of the indices in node n is greater than the
+ * expected tuple's length
+ */
+ static void checkTypeIndices(Node n,
+ TypeNode tupleType,
+ const std::vector<uint32_t> indices);
+ /**
+ * @param tupleType1 tuple type
+ * @param tupleType2 tuple type
+ * @return the type of concatenation of tupleType1, tupleType2
+ */
+ static TypeNode concatTupleTypes(TypeNode tupleType1, TypeNode tupleType2);
+
/**
* @param tuple a node of tuple type
* @param n_th the index of the element to be extracted, and must satisfy the
case InferenceId::BAGS_CARD_EMPTY: return "BAGS_CARD_EMPTY";
case InferenceId::TABLES_PRODUCT_UP: return "TABLES_PRODUCT_UP";
case InferenceId::TABLES_PRODUCT_DOWN: return "TABLES_PRODUCT_DOWN";
+ case InferenceId::TABLES_JOIN_UP: return "TABLES_JOIN_UP";
+ case InferenceId::TABLES_JOIN_DOWN: return "TABLES_JOIN_DOWN";
case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT";
case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA:
BAGS_CARD_EMPTY,
TABLES_PRODUCT_UP,
TABLES_PRODUCT_DOWN,
+ TABLES_JOIN_UP,
+ TABLES_JOIN_DOWN,
// ---------------------------------- end bags theory
// ---------------------------------- bitvector theory
regress1/bags/proj-issue497.smt2
regress1/bags/subbag1.smt2
regress1/bags/subbag2.smt2
+ regress1/bags/table_aggregate1.smt2
+ regress1/bags/table_join1.smt2
+ regress1/bags/table_join2.smt2
+ regress1/bags/table_join3.smt2
regress1/bags/table_project1.smt2
regress1/bags/union_disjoint.smt2
regress1/bags/union_max1.smt2
--- /dev/null
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+
+(define-fun sumByCategory ((x (Tuple String String Int)) (y (Tuple String Int))) (Tuple String Int)
+ (tuple
+ ((_ tuple.select 0) x)
+ (+ ((_ tuple.select 2) x) ((_ tuple.select 1) y))))
+
+(declare-fun categorySales () (Bag (Tuple String Int)))
+
+;(define-fun categorySales () (Bag (Tuple String Int))
+; (bag.union_disjoint
+; (bag (tuple "Hardware" 10) 1)
+; (bag (tuple "Software" 6) 1)))
+
+(assert
+ (= categorySales
+ ((_ table.aggr 0)
+ sumByCategory
+ (tuple "" 0)
+ (bag.union_disjoint
+ (bag (tuple "Software" "win" 1) 2)
+ (bag (tuple "Software" "mac" 4) 1)
+ (bag (tuple "Hardware" "cpu" 2) 2)
+ (bag (tuple "Hardware" "gpu" 3) 2)))))
+
+(check-sat)
--- /dev/null
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun Departments () (Table Int String))
+(declare-fun Students () (Table Int String Int))
+(declare-fun DepartmentStudents () (Table Int String Int String Int))
+
+(assert
+ (= Departments
+ (bag.union_disjoint
+ (bag (tuple 1 "Computer") 1)
+ (bag (tuple 2 "Engineering") 1))))
+
+(assert
+ (= Students
+ (bag.union_disjoint
+ (bag (tuple 1 "A" 1) 1)
+ (bag (tuple 2 "B" 1) 1)
+ (bag (tuple 3 "C" 2) 1))))
+
+;(define-fun DepartmentStudents () (Bag (Tuple Int String Int String Int))
+; (bag.union_disjoint (bag (tuple 1 "Computer" 1 "A" 1) 1)
+; (bag (tuple 1 "Computer" 2 "B" 1) 1)
+; (bag (tuple 2 "Engineering" 3 "C" 2) 1)))
+
+(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students)))
+
+(check-sat)
--- /dev/null
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun Departments () (Table Int String))
+(declare-fun Students () (Table Int String Int))
+(declare-fun DepartmentStudents () (Table Int String Int String Int))
+
+;(define-fun Departments () (Bag (Tuple Int String))
+; (bag.union_disjoint
+; (bag (tuple 1 "Computer") 1)
+; (bag (tuple 2 "Engineering") 1)))
+
+;(define-fun Students () (Bag (Tuple Int String Int))
+; (bag.union_disjoint
+; (bag (tuple 1 "A" 1) 1)
+; (bag (tuple 2 "B" 1) 1)
+; (bag (tuple 3 "C" 2) 1)))
+
+(assert
+ (= DepartmentStudents
+ (bag.union_disjoint (bag (tuple 1 "Computer" 1 "A" 1) 1)
+ (bag (tuple 1 "Computer" 2 "B" 1) 1)
+ (bag (tuple 2 "Engineering" 3 "C" 2) 1))))
+
+(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students)))
+
+(check-sat)
--- /dev/null
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun Departments () (Table Int String))
+(declare-fun Students () (Table Int String Int))
+(declare-fun DepartmentStudents () (Table Int String Int String Int))
+
+(declare-fun d1 () (Tuple Int String))
+(declare-fun d2 () (Tuple Int String))
+(assert (distinct d1 d2))
+
+(declare-fun s1 () (Tuple Int String Int))
+(declare-fun s2 () (Tuple Int String Int))
+(assert (distinct s1 s2))
+
+(assert
+ (distinct DepartmentStudents (as bag.empty (Table Int String Int String Int))))
+
+(assert (bag.member d1 Departments))
+(assert (bag.member d2 Departments))
+(assert (bag.member s1 Students))
+(assert (bag.member s2 Students))
+
+(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students)))
+
+(check-sat)
return elements;
}
+ Node rewrite(Node n) { return d_rewriter->postRewrite(n).d_node; }
+
std::unique_ptr<BagsRewriter> d_rewriter;
};
Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType()));
// empty bags are in normal form
ASSERT_TRUE(emptybag.isConst());
- Node n = BagsUtils::evaluate(emptybag);
+ Node n = rewrite(emptybag);
ASSERT_EQ(emptybag, n);
}
ASSERT_FALSE(negative.isConst());
ASSERT_FALSE(zero.isConst());
- ASSERT_EQ(emptybag, BagsUtils::evaluate(negative));
- ASSERT_EQ(emptybag, BagsUtils::evaluate(zero));
- ASSERT_EQ(positive, BagsUtils::evaluate(positive));
+ ASSERT_EQ(emptybag, rewrite(negative));
+ ASSERT_EQ(emptybag, rewrite(zero));
+ ASSERT_EQ(positive, rewrite(positive));
}
TEST_F(TestTheoryWhiteBagsNormalForm, bag_count)
Node input1 = d_nodeManager->mkNode(BAG_COUNT, x, empty);
Node output1 = zero;
- ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+ ASSERT_EQ(output1, rewrite(input1));
Node input2 = d_nodeManager->mkNode(BAG_COUNT, x, y_5);
Node output2 = zero;
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
Node input3 = d_nodeManager->mkNode(BAG_COUNT, x, x_4);
Node output3 = four;
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
Node unionDisjointXY = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
Node input4 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointXY);
Node output4 = four;
- ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+ ASSERT_EQ(output3, rewrite(input3));
Node unionDisjointYZ = d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_5, z_5);
Node input5 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointYZ);
Node output5 = zero;
- ASSERT_EQ(output4, BagsUtils::evaluate(input4));
+ ASSERT_EQ(output4, rewrite(input4));
}
TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
Node input1 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, emptybag);
Node output1 = emptybag;
- ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+ ASSERT_EQ(output1, rewrite(input1));
Node x = d_nodeManager->mkConst(String("x"));
Node y = d_nodeManager->mkConst(String("y"));
Node input2 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, x_4);
Node output2 = x_1;
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
Node input3 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, normalBag);
Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
- ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+ ASSERT_EQ(output3, rewrite(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, union_max)
d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, BagsUtils::evaluate(input));
+ ASSERT_EQ(output, rewrite(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
Node unionDisjointAB = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B);
// unionDisjointAB is already in a normal form
ASSERT_TRUE(unionDisjointAB.isConst());
- ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointAB));
+ ASSERT_EQ(unionDisjointAB, rewrite(unionDisjointAB));
Node unionDisjointBA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, B, A);
// unionDisjointAB is the normal form of unionDisjointBA
ASSERT_FALSE(unionDisjointBA.isConst());
- ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointBA));
+ ASSERT_EQ(unionDisjointAB, rewrite(unionDisjointBA));
Node unionDisjointAB_C =
d_nodeManager->mkNode(BAG_UNION_DISJOINT, unionDisjointAB, C);
// unionDisjointA_BC is the normal form of unionDisjointAB_C
ASSERT_FALSE(unionDisjointAB_C.isConst());
ASSERT_TRUE(unionDisjointA_BC.isConst());
- ASSERT_EQ(unionDisjointA_BC, BagsUtils::evaluate(unionDisjointAB_C));
+ ASSERT_EQ(unionDisjointA_BC, rewrite(unionDisjointAB_C));
Node unionDisjointAA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, A);
Node AA = d_nodeManager->mkBag(d_nodeManager->stringType(),
d_nodeManager->mkConstInt(Rational(4)));
ASSERT_FALSE(unionDisjointAA.isConst());
ASSERT_TRUE(AA.isConst());
- ASSERT_EQ(AA, BagsUtils::evaluate(unionDisjointAA));
+ ASSERT_EQ(AA, rewrite(unionDisjointAA));
}
TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2)
d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, BagsUtils::evaluate(input));
+ ASSERT_EQ(output, rewrite(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min)
Node output = x_3;
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, BagsUtils::evaluate(input));
+ ASSERT_EQ(output, rewrite(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract)
Node output = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, z_2);
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, BagsUtils::evaluate(input));
+ ASSERT_EQ(output, rewrite(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove)
Node output = z_2;
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, BagsUtils::evaluate(input));
+ ASSERT_EQ(output, rewrite(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, bag_card)
Node input1 = d_nodeManager->mkNode(BAG_CARD, empty);
Node output1 = d_nodeManager->mkConstInt(Rational(0));
- ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+ ASSERT_EQ(output1, rewrite(input1));
Node input2 = d_nodeManager->mkNode(BAG_CARD, x_4);
Node output2 = d_nodeManager->mkConstInt(Rational(4));
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_1);
Node input3 = d_nodeManager->mkNode(BAG_CARD, union_disjoint);
Node output3 = d_nodeManager->mkConstInt(Rational(5));
- ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+ ASSERT_EQ(output3, rewrite(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton)
Node input1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, empty);
Node output1 = falseNode;
- ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+ ASSERT_EQ(output1, rewrite(input1));
Node input2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_1);
Node output2 = trueNode;
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
Node input3 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_4);
Node output3 = falseNode;
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
Node input4 = d_nodeManager->mkNode(BAG_IS_SINGLETON, union_disjoint);
Node output4 = falseNode;
- ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+ ASSERT_EQ(output3, rewrite(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
Node input1 = d_nodeManager->mkNode(BAG_FROM_SET, emptyset);
Node output1 = emptybag;
- ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+ ASSERT_EQ(output1, rewrite(input1));
Node x = d_nodeManager->mkConst(String("x"));
Node y = d_nodeManager->mkConst(String("y"));
Node input2 = d_nodeManager->mkNode(BAG_FROM_SET, xSingleton);
Node output2 = x_1;
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
// for normal sets, the first node is the largest, not smallest
Node normalSet = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
Node input3 = d_nodeManager->mkNode(BAG_FROM_SET, normalSet);
Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
- ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+ ASSERT_EQ(output3, rewrite(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
Node input1 = d_nodeManager->mkNode(BAG_TO_SET, emptybag);
Node output1 = emptyset;
- ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+ ASSERT_EQ(output1, rewrite(input1));
Node x = d_nodeManager->mkConst(String("x"));
Node y = d_nodeManager->mkConst(String("y"));
Node input2 = d_nodeManager->mkNode(BAG_TO_SET, x_4);
Node output2 = xSingleton;
- ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+ ASSERT_EQ(output2, rewrite(input2));
// for normal sets, the first node is the largest, not smallest
Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
Node input3 = d_nodeManager->mkNode(BAG_TO_SET, normalBag);
Node output3 = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
- ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+ ASSERT_EQ(output3, rewrite(input3));
}
} // namespace test
} // namespace cvc5::internal