Add operators table.aggr and table.join (#8681)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 5 May 2022 13:41:30 +0000 (08:41 -0500)
committerGitHub <noreply@github.com>
Thu, 5 May 2022 13:41:30 +0000 (13:41 +0000)
36 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/Smt2.g
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_reduction.cpp
src/theory/bags/bag_reduction.h
src/theory/bags/bag_solver.cpp
src/theory/bags/bag_solver.h
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/bags_utils.cpp
src/theory/bags/bags_utils.h
src/theory/bags/card_solver.cpp
src/theory/bags/card_solver.h
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/kinds
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
src/theory/bags/table_project_op.cpp
src/theory/bags/table_project_op.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
src/theory/datatypes/tuple_utils.cpp
src/theory/datatypes/tuple_utils.h
src/theory/inference_id.cpp
src/theory/inference_id.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/bags/table_aggregate1.smt2 [new file with mode: 0644]
test/regress/cli/regress1/bags/table_join1.smt2 [new file with mode: 0644]
test/regress/cli/regress1/bags/table_join2.smt2 [new file with mode: 0644]
test/regress/cli/regress1/bags/table_join3.smt2 [new file with mode: 0644]
test/unit/theory/theory_bags_normal_form_white.cpp

index 115ecddc988c35fb3c72fa158f38187b2c14e3f3..0c6167963c621c0454ec1854d305dc12a3f70315 100644 (file)
@@ -331,6 +331,8 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(BAG_PARTITION, internal::Kind::BAG_PARTITION),
         KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT),
         KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT),
+        KIND_ENUM(TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE),
+        KIND_ENUM(TABLE_JOIN, internal::Kind::TABLE_JOIN),
         /* Strings ---------------------------------------------------------- */
         KIND_ENUM(STRING_CONCAT, internal::Kind::STRING_CONCAT),
         KIND_ENUM(STRING_IN_REGEXP, internal::Kind::STRING_IN_REGEXP),
@@ -649,6 +651,10 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::TABLE_PRODUCT, TABLE_PRODUCT},
         {internal::Kind::TABLE_PROJECT, TABLE_PROJECT},
         {internal::Kind::TABLE_PROJECT_OP, TABLE_PROJECT},
+        {internal::Kind::TABLE_AGGREGATE_OP, TABLE_AGGREGATE},
+        {internal::Kind::TABLE_AGGREGATE, TABLE_AGGREGATE},
+        {internal::Kind::TABLE_JOIN_OP, TABLE_JOIN},
+        {internal::Kind::TABLE_JOIN, TABLE_JOIN},
         /* Strings --------------------------------------------------------- */
         {internal::Kind::STRING_CONCAT, STRING_CONCAT},
         {internal::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
@@ -6217,6 +6223,12 @@ Op Solver::mkOp(Kind kind, const std::vector<uint32_t>& args) const
     case TABLE_PROJECT:
       res = mkOpHelper(kind, internal::TableProjectOp(args));
       break;
+    case TABLE_AGGREGATE:
+      res = mkOpHelper(kind, internal::TableAggregateOp(args));
+      break;
+    case TABLE_JOIN:
+      res = mkOpHelper(kind, internal::TableJoinOp(args));
+      break;
     default:
       if (nargs == 0)
       {
index 8ee2f378c0996f5a141fd012d9db48f18480de73..a7fa5644c658e381dc37338447198baaad89f123 100644 (file)
@@ -3726,10 +3726,10 @@ enum Kind : int32_t
    * Table projection operator extends tuple projection operator to tables.
    *
    * - Arity: ``1``
-   *   - ``1:`` Term of tuple Sort
+   *   - ``1:`` Term of table Sort
    *
    * - Indices: ``n``
-   *   - ``1..n:`` The table indices to project
+   *   - ``1..n:`` Indices of the projection
    *
    * - Create Term of this Kind with:
    *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
@@ -3738,6 +3738,58 @@ enum Kind : int32_t
    *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
    */
   TABLE_PROJECT,
+  /**
+   * Table aggregate operator has the form
+   * :math:`((\_ \; table.aggr \; n_1 ... n_k) f i A)`
+   * where :math:`n_1, ..., n_k` are natural numbers,
+   * :math:`f` is a function of type
+   * :math:`(\rightarrow (Tuple \;  T_1 \; ... \; T_j)\; T \; T)`,
+   * :math:`i` has the type :math:`T`,
+   * and :math`A` has type :math:`Table T_1 ... T_j`.
+   * The returned type is :math:`(Bag \; T)`.
+   *
+   * This operator aggregates elements in A that have the same tuple projection
+   * with indices n_1, ..., n_k using the combining function :math:`f`,
+   * and initial value :math:`i`.
+   *
+   * - Arity: ``3``
+   *   - ``1:`` Term of sort :math:`(\rightarrow (Tuple \;  T_1 \; ... \; T_j)\; T \; T)`
+   *   - ``2:`` Term of Sort :math:`T`
+   *   - ``3:`` Term of table sort :math:`Table T_1 ... T_j`
+   *
+   * - Indices: ``n``
+   *   - ``1..n:`` Indices of the projection
+   *
+   * - Create Term of this Kind with:
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * - Create Op of this kind with:
+   *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   */
+  TABLE_AGGREGATE,
+  /**
+   * Table join operator has the form
+   *  :math:`((\_ \; table.join \; m_1 \; n_1 \; \dots \; m_k \; n_k) \; A \; B)`
+   *  where *  :math:`m_1 \; n_1 \; \dots \; m_k \; n_k` are natural numbers,
+   *  and A, B are tables.
+   *  This operator filters the product of two bags based on the equality of
+   *  projected tuples using indices :math:`m_1, \dots, m_k` in table A,
+   *  and indices :math:`n_1, \dots, n_k` in table B.
+   *
+   * - Arity: ``2``
+   *   - ``1:`` Term of table Sort
+   *   - ``2:`` Term of table Sort
+   *
+   * - Indices: ``n``
+   *   - ``1..n:``  Indices of the projection
+   *
+   * - Create Term of this Kind with:
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * - Create Op of this kind with:
+   *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   */
+  TABLE_JOIN,
   /* Strings --------------------------------------------------------------- */
 
   /**
index 2a5dcf162323ef2b858e3587e9ef4e6cf251a6e2..239d36468fe35d92178890613716f55e23b15daa 100644 (file)
@@ -1413,6 +1413,18 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2]
     cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_PROJECT, indices);
     expr = SOLVER->mkTerm(op, {expr});
   }
+  | LPAREN_TOK TABLE_AGGREGATE_TOK term[expr,expr2] RPAREN_TOK
+  {
+    std::vector<uint32_t> indices;
+    cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_AGGREGATE, indices);
+    expr = SOLVER->mkTerm(op, {expr});
+  }
+  | LPAREN_TOK TABLE_JOIN_TOK term[expr,expr2] RPAREN_TOK
+  {
+    std::vector<uint32_t> indices;
+    cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_JOIN, indices);
+    expr = SOLVER->mkTerm(op, {expr});
+  }
   | /* an atomic term (a term with no subterms) */
     termAtomic[atomTerm] { expr = atomTerm; }
   ;
@@ -1561,6 +1573,20 @@ identifier[cvc5::ParseOp& p]
         p.d_kind = cvc5::TABLE_PROJECT;
         p.d_op = SOLVER->mkOp(cvc5::TABLE_PROJECT, numerals);
        }
+    | TABLE_AGGREGATE_TOK nonemptyNumeralList[numerals]
+      {
+        // we adopt a special syntax (_ table.aggr i_1 ... i_n) where
+        // i_1, ..., i_n are numerals
+        p.d_kind = cvc5::TABLE_AGGREGATE;
+        p.d_op = SOLVER->mkOp(cvc5::TABLE_AGGREGATE, numerals);
+      }
+    | TABLE_JOIN_TOK nonemptyNumeralList[numerals]
+      {
+        // we adopt a special syntax (_ table.join i_1 j_1 ... i_n j_n) where
+        // i_1, ..., i_n, j_1, ..., j_n are numerals
+        p.d_kind = cvc5::TABLE_JOIN;
+        p.d_op = SOLVER->mkOp(cvc5::TABLE_JOIN, numerals);
+      }
     | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals]
       {
         cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName);
@@ -2170,6 +2196,8 @@ CHAR_TOK : { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_STRINGS) }?
 TUPLE_CONST_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple';
 TUPLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple.project';
 TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.project';
+TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.aggr';
+TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join';
 FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card';
 
 HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->';
index a4a16c214ad00ff700165585708a727196b3f351..dad36788ea4dcbaa56a3cd40ab71b6f53b6a7cb1 100644 (file)
@@ -1125,7 +1125,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector<cvc5::Term>& args)
     Trace("parser") << "applyParseOp: return selector " << ret << std::endl;
     return ret;
   }
-  else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT)
+  else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT
+           || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN)
   {
     cvc5::Term ret = d_solver->mkTerm(p.d_op, args);
     Trace("parser") << "applyParseOp: return projection " << ret << std::endl;
index 19c67f67283aee4eda01d35c3c627b6452ccad9d..3cda30e18190138d84f51bc1b48542dc2e1f172a 100644 (file)
@@ -806,6 +806,37 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     return;
   }
+  case kind::TABLE_AGGREGATE:
+  {
+    TableAggregateOp op = n.getOperator().getConst<TableAggregateOp>();
+    if (op.getIndices().empty())
+    {
+      // e.g. (table.project function initial_value bag)
+      out << "table.aggr " << n[0] << " " << n[1] << " " << n[2] << ")";
+    }
+    else
+    {
+      // e.g.  ((_ table.aggr 0) function initial_value bag)
+      out << "(_ table.aggr" << op << ") " << n[0] << " " << n[1] << " " << n[2]
+          << ")";
+    }
+    return;
+  }
+  case kind::TABLE_JOIN:
+  {
+    TableJoinOp op = n.getOperator().getConst<TableJoinOp>();
+    if (op.getIndices().empty())
+    {
+      // e.g. (table.join A B)
+      out << "table.join " << n[0] << " " << n[1] << ")";
+    }
+    else
+    {
+      // e.g. ((_ table.project 0 1 2 3) A B)
+      out << "(_ table.join" << op << ") " << n[0] << " " << n[1] << ")";
+    }
+    return;
+  }
   case kind::CONSTRUCTOR_TYPE:
   {
     out << n[n.getNumChildren()-1];
@@ -978,7 +1009,7 @@ void Smt2Printer::toStream(std::ostream& out,
     }
   }
   stringstream parens;
-  
+
   for(size_t i = 0, c = 1; i < n.getNumChildren(); ) {
     if(toDepth != 0) {
       toStream(out, n[i], toDepth < 0 ? toDepth : toDepth - c, lbind);
@@ -1074,9 +1105,9 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::GEQ: return ">=";
   case kind::DIVISION:
   case kind::DIVISION_TOTAL: return "/";
-  case kind::INTS_DIVISION_TOTAL: 
+  case kind::INTS_DIVISION_TOTAL:
   case kind::INTS_DIVISION: return "div";
-  case kind::INTS_MODULUS_TOTAL: 
+  case kind::INTS_MODULUS_TOTAL:
   case kind::INTS_MODULUS: return "mod";
   case kind::ABS: return "abs";
   case kind::IS_INTEGER: return "is_int";
@@ -1186,6 +1217,8 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_PARTITION: return "bag.partition";
   case kind::TABLE_PRODUCT: return "table.product";
   case kind::TABLE_PROJECT: return "table.project";
+  case kind::TABLE_AGGREGATE: return "table.aggr";
+  case kind::TABLE_JOIN: return "table.join";
 
     // fp theory
   case kind::FLOATINGPOINT_FP: return "fp";
index 43e1e7ee93021adbe7b253daf42d7430f27b1c0d..5e6d7193521e47dea1935b060c27c6e81164e4e8 100644 (file)
@@ -18,6 +18,8 @@
 #include "expr/bound_var_manager.h"
 #include "expr/emptybag.h"
 #include "expr/skolem_manager.h"
+#include "table_project_op.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "theory/quantifiers/fmf/bounded_integers.h"
 #include "util/rational.h"
 
@@ -28,7 +30,7 @@ namespace cvc5::internal {
 namespace theory {
 namespace bags {
 
-BagReduction::BagReduction(Env& env) : EnvObj(env) {}
+BagReduction::BagReduction() {}
 
 BagReduction::~BagReduction() {}
 
@@ -58,74 +60,68 @@ typedef expr::Attribute<SecondIndexVarAttributeId, Node>
 Node BagReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
 {
   Assert(node.getKind() == BAG_FOLD);
-  if (d_env.getLogicInfo().isHigherOrder())
-  {
-    NodeManager* nm = NodeManager::currentNM();
-    SkolemManager* sm = nm->getSkolemManager();
-    Node f = node[0];
-    Node t = node[1];
-    Node A = node[2];
-    Node zero = nm->mkConstInt(Rational(0));
-    Node one = nm->mkConstInt(Rational(1));
-    // types
-    TypeNode bagType = A.getType();
-    TypeNode elementType = A.getType().getBagElementType();
-    TypeNode integerType = nm->integerType();
-    TypeNode ufType = nm->mkFunctionType(integerType, elementType);
-    TypeNode resultType = t.getType();
-    TypeNode combineType = nm->mkFunctionType(integerType, resultType);
-    TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType);
-    // skolem functions
-    Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A);
-    Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A);
-    Node unionDisjoint = sm->mkSkolemFunction(
-        SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A);
-    Node combine = sm->mkSkolemFunction(
-        SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A});
-
-    BoundVarManager* bvm = nm->getBoundVarManager();
-    Node i =
-        bvm->mkBoundVar<FirstIndexVarAttribute>(node, "i", nm->integerType());
-    Node iList = nm->mkNode(BOUND_VAR_LIST, i);
-    Node iMinusOne = nm->mkNode(SUB, i, one);
-    Node uf_i = nm->mkNode(APPLY_UF, uf, i);
-    Node combine_0 = nm->mkNode(APPLY_UF, combine, zero);
-    Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne);
-    Node combine_i = nm->mkNode(APPLY_UF, combine, i);
-    Node combine_n = nm->mkNode(APPLY_UF, combine, n);
-    Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero);
-    Node unionDisjoint_iMinusOne =
-        nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne);
-    Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i);
-    Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n);
-    Node combine_0_equal = combine_0.eqNode(t);
-    Node combine_i_equal =
-        combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne));
-    Node unionDisjoint_0_equal =
-        unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
-    Node singleton = nm->mkBag(elementType, uf_i, one);
-
-    Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
-        nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
-    Node interval_i =
-        nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n));
-
-    Node body_i =
-        nm->mkNode(IMPLIES,
-                   interval_i,
-                   nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal));
-    Node forAll_i =
-        quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
-    Node nonNegative = nm->mkNode(GEQ, n, zero);
-    Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
-    asserts.push_back(forAll_i);
-    asserts.push_back(combine_0_equal);
-    asserts.push_back(unionDisjoint_0_equal);
-    asserts.push_back(unionDisjoint_n_equal);
-    asserts.push_back(nonNegative);
-    return combine_n;
-  }
-  return Node::null();
+  NodeManager* nm = NodeManager::currentNM();
+  SkolemManager* sm = nm->getSkolemManager();
+  Node f = node[0];
+  Node t = node[1];
+  Node A = node[2];
+  Node zero = nm->mkConstInt(Rational(0));
+  Node one = nm->mkConstInt(Rational(1));
+  // types
+  TypeNode bagType = A.getType();
+  TypeNode elementType = A.getType().getBagElementType();
+  TypeNode integerType = nm->integerType();
+  TypeNode ufType = nm->mkFunctionType(integerType, elementType);
+  TypeNode resultType = t.getType();
+  TypeNode combineType = nm->mkFunctionType(integerType, resultType);
+  TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType);
+  // skolem functions
+  Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A);
+  Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A);
+  Node unionDisjoint = sm->mkSkolemFunction(
+      SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A);
+  Node combine = sm->mkSkolemFunction(
+      SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A});
+
+  BoundVarManager* bvm = nm->getBoundVarManager();
+  Node i =
+      bvm->mkBoundVar<FirstIndexVarAttribute>(node, "i", nm->integerType());
+  Node iList = nm->mkNode(BOUND_VAR_LIST, i);
+  Node iMinusOne = nm->mkNode(SUB, i, one);
+  Node uf_i = nm->mkNode(APPLY_UF, uf, i);
+  Node combine_0 = nm->mkNode(APPLY_UF, combine, zero);
+  Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne);
+  Node combine_i = nm->mkNode(APPLY_UF, combine, i);
+  Node combine_n = nm->mkNode(APPLY_UF, combine, n);
+  Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero);
+  Node unionDisjoint_iMinusOne = nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne);
+  Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i);
+  Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n);
+  Node combine_0_equal = combine_0.eqNode(t);
+  Node combine_i_equal =
+      combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne));
+  Node unionDisjoint_0_equal =
+      unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
+  Node singleton = nm->mkBag(elementType, uf_i, one);
+
+  Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
+      nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
+  Node interval_i =
+      nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n));
+
+  Node body_i =
+      nm->mkNode(IMPLIES,
+                 interval_i,
+                 nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal));
+  Node forAll_i = quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
+  Node nonNegative = nm->mkNode(GEQ, n, zero);
+  Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
+  asserts.push_back(forAll_i);
+  asserts.push_back(combine_0_equal);
+  asserts.push_back(unionDisjoint_0_equal);
+  asserts.push_back(unionDisjoint_n_equal);
+  asserts.push_back(nonNegative);
+  return combine_n;
 }
 
 Node BagReduction::reduceCardOperator(Node node, std::vector<Node>& asserts)
@@ -206,6 +202,43 @@ Node BagReduction::reduceCardOperator(Node node, std::vector<Node>& asserts)
   return cardinality_n;
 }
 
+Node BagReduction::reduceAggregateOperator(Node node)
+{
+  Assert(node.getKind() == TABLE_AGGREGATE);
+  NodeManager* nm = NodeManager::currentNM();
+  BoundVarManager* bvm = nm->getBoundVarManager();
+  Node function = node[0];
+  TypeNode elementType = function.getType().getArgTypes()[0];
+  Node initialValue = node[1];
+  Node A = node[2];
+  const std::vector<uint32_t>& indices =
+      node.getOperator().getConst<TableAggregateOp>().getIndices();
+
+  Node t1 = bvm->mkBoundVar<FirstIndexVarAttribute>(node, "t1", elementType);
+  Node t2 = bvm->mkBoundVar<SecondIndexVarAttribute>(node, "t2", elementType);
+  Node list = nm->mkNode(BOUND_VAR_LIST, t1, t2);
+  Node body = nm->mkConst(true);
+  for (uint32_t i : indices)
+  {
+    Node select1 = datatypes::TupleUtils::nthElementOfTuple(t1, i);
+    Node select2 = datatypes::TupleUtils::nthElementOfTuple(t2, i);
+    Node equal = select1.eqNode(select2);
+    body = body.andNode(equal);
+  }
+
+  Node lambda = nm->mkNode(LAMBDA, list, body);
+  Node partition = nm->mkNode(BAG_PARTITION, lambda, A);
+
+  Node bag = bvm->mkBoundVar<FirstIndexVarAttribute>(
+      partition, "bag", nm->mkBagType(elementType));
+  Node foldList = nm->mkNode(BOUND_VAR_LIST, bag);
+  Node foldBody = nm->mkNode(BAG_FOLD, function, initialValue, bag);
+
+  Node fold = nm->mkNode(LAMBDA, foldList, foldBody);
+  Node map = nm->mkNode(BAG_MAP, fold, partition);
+  return map;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5::internal
index 55933b1dc87bad24036751f158ec3578fed89320..c3f49b0a443770aeb3321a2b3b704310c6bf1426 100644 (file)
@@ -29,10 +29,10 @@ namespace bags {
 /**
  * class for bag reductions
  */
-class BagReduction : EnvObj
+class BagReduction
 {
  public:
-  BagReduction(Env& env);
+  BagReduction();
   ~BagReduction();
 
   /**
@@ -64,7 +64,7 @@ class BagReduction : EnvObj
    * combine: Int -> T2 is an uninterpreted function
    * unionDisjoint: Int -> (Bag T1) is an uninterpreted function
    */
-  Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
+  static Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
 
   /**
    * @param node a term of the form (bag.card A) where A: (Bag T) is a bag
@@ -95,9 +95,21 @@ class BagReduction : EnvObj
    *   cardinality: Int -> Int is an uninterpreted function
    *   unionDisjoint: Int -> (Bag T1) is an uninterpreted function
    */
-  Node reduceCardOperator(Node node, std::vector<Node>& asserts);
-
- private:
+  static Node reduceCardOperator(Node node, std::vector<Node>& asserts);
+  /**
+   * @param node of the form ((_ table.aggr n1 ... nk) f initial A))
+   * @return reduction term that uses map, fold, and partition using
+   * tuple projection as the equivalence relation as follows:
+   * (bag.map
+   *   (lambda ((B Table)) (bag.fold f initial B))
+   *   (bag.partition
+   *     (lambda ((t1 Tuple) (t2 Tuple)) ; equivalence relation
+   *             (=
+   *               ((_ tuple.project n1 ... nk) t1)
+   *               ((_ tuple.project n1 ... nk) t2)))
+   *     A))
+   */
+  static Node reduceAggregateOperator(Node node);
 };
 
 }  // namespace bags
index f7864d01ac7d2d84031e8fe68c38473b7813e0a5..5118df4621785b9672f4ff2a9de952d08c19e9e0 100644 (file)
@@ -79,6 +79,7 @@ void BagSolver::checkBasicOperations()
         case kind::BAG_FILTER: checkFilter(n); break;
         case kind::BAG_MAP: checkMap(n); break;
         case kind::TABLE_PRODUCT: checkProduct(n); break;
+        case kind::TABLE_JOIN: checkJoin(n); break;
         default: break;
       }
       it++;
@@ -335,6 +336,30 @@ void BagSolver::checkProduct(Node n)
   }
 }
 
+void BagSolver::checkJoin(Node n)
+{
+  Assert(n.getKind() == TABLE_JOIN);
+  const set<Node>& elementsA = d_state.getElements(n[0]);
+  const set<Node>& elementsB = d_state.getElements(n[1]);
+
+  for (const Node& e1 : elementsA)
+  {
+    for (const Node& e2 : elementsB)
+    {
+      InferInfo i = d_ig.joinUp(
+          n, d_state.getRepresentative(e1), d_state.getRepresentative(e2));
+      d_im.lemmaTheoryInference(&i);
+    }
+  }
+
+  std::set<Node> elements = d_state.getElements(n);
+  for (const Node& e : elements)
+  {
+    InferInfo i = d_ig.joinDown(n, d_state.getRepresentative(e));
+    d_im.lemmaTheoryInference(&i);
+  }
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5::internal
index c0737ef2fbe1d01c1a77dfbc6785cec1c767db50..2d6e890df900bfac604dade7fe33465ca8267430 100644 (file)
@@ -100,6 +100,8 @@ class BagSolver : protected EnvObj
   void checkFilter(Node n);
   /** apply inference rules for product operator */
   void checkProduct(Node n);
+  /** apply inference rules for join operator */
+  void checkJoin(Node n);
 
   /** The solver state object */
   SolverState& d_state;
index 6b7e49a31d3250a0061ebbdf4c24e44bccd80836..e401551a76819ee5019131fb74f8f9399bd38511 100644 (file)
@@ -68,7 +68,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
   }
   else if (BagsUtils::areChildrenConstants(n))
   {
-    Node value = BagsUtils::evaluate(n);
+    Node value = BagsUtils::evaluate(d_rewriter, n);
     response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
   }
   else
@@ -95,6 +95,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
       case BAG_FOLD: response = postRewriteFold(n); break;
       case BAG_PARTITION: response = postRewritePartition(n); break;
       case TABLE_PRODUCT: response = postRewriteProduct(n); break;
+      case TABLE_AGGREGATE: response = postRewriteAggregate(n); break;
       default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
     }
   }
@@ -665,6 +666,21 @@ BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const
   return BagsRewriteResponse(n, Rewrite::NONE);
 }
 
+BagsRewriteResponse BagsRewriter::postRewriteAggregate(const TNode& n) const
+{
+  Assert(n.getKind() == kind::TABLE_AGGREGATE);
+  if (n[1].isConst() && n[2].isConst())
+  {
+    Node ret = BagsUtils::evaluateTableAggregate(d_rewriter, n);
+    if (ret != n)
+    {
+      return BagsRewriteResponse(ret, Rewrite::AGGREGATE_CONST);
+    }
+  }
+
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
 BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
 {
   Assert(n.getKind() == TABLE_PRODUCT);
index 3c08208a80edad56f6b7b62c3c9f999be9dc9f1f..711846143acc8e1ca9ca791d8915c4580d574224 100644 (file)
@@ -247,6 +247,7 @@ class BagsRewriter : public TheoryRewriter
    */
   BagsRewriteResponse postRewriteFold(const TNode& n) const;
   BagsRewriteResponse postRewritePartition(const TNode& n) const;
+  BagsRewriteResponse postRewriteAggregate(const TNode& n) const;
   /**
    *  rewrites for n include:
    *  - (bag.product A (as bag.empty T2)) = (as bag.empty T)
index fd5a98c25aebf6d5d02a31990ad04d8d47f644a2..7ee00c432cf5823cbf2d623a35208e8e520eada1 100644 (file)
@@ -19,6 +19,7 @@
 #include "expr/emptybag.h"
 #include "smt/logic_exception.h"
 #include "table_project_op.h"
+#include "theory/bags/bag_reduction.h"
 #include "theory/datatypes/tuple_utils.h"
 #include "theory/rewriter.h"
 #include "theory/sets/normal_form.h"
@@ -118,7 +119,7 @@ bool BagsUtils::areChildrenConstants(TNode n)
   return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
 }
 
-Node BagsUtils::evaluate(TNode n)
+Node BagsUtils::evaluate(Rewriter* rewriter, TNode n)
 {
   Assert(areChildrenConstants(n));
   if (n.isConst())
@@ -144,6 +145,7 @@ Node BagsUtils::evaluate(TNode n)
     case BAG_FILTER: return evaluateBagFilter(n);
     case BAG_FOLD: return evaluateBagFold(n);
     case TABLE_PRODUCT: return evaluateProduct(n);
+    case TABLE_JOIN: return evaluateJoin(rewriter, n);
     case TABLE_PROJECT: return evaluateTableProject(n);
     default: break;
   }
@@ -879,9 +881,22 @@ Node BagsUtils::evaluateBagPartition(Rewriter* rewriter, TNode n)
   return ret;
 }
 
+Node BagsUtils::evaluateTableAggregate(Rewriter* rewriter, TNode n)
+{
+  Assert(n.getKind() == TABLE_AGGREGATE);
+  if (!(n[1].isConst() && n[2].isConst()))
+  {
+    // we can't proceed further.
+    return n;
+  }
+
+  Node reduction = BagReduction::reduceAggregateOperator(n);
+  return reduction;
+}
+
 Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2)
 {
-  Assert(n.getKind() == TABLE_PRODUCT);
+  Assert(n.getKind() == TABLE_PRODUCT || n.getKind() == TABLE_JOIN);
   Node A = n[0];
   Node B = n[1];
   TypeNode typeA = A.getType().getBagElementType();
@@ -925,6 +940,41 @@ Node BagsUtils::evaluateProduct(TNode n)
   return ret;
 }
 
+Node BagsUtils::evaluateJoin(Rewriter* rewriter, TNode n)
+{
+  Assert(n.getKind() == TABLE_JOIN);
+
+  Node A = n[0];
+  Node B = n[1];
+  auto [aIndices, bIndices] = splitTableJoinIndices(n);
+
+  std::map<Node, Rational> elementsA = BagsUtils::getBagElements(A);
+  std::map<Node, Rational> elementsB = BagsUtils::getBagElements(B);
+
+  std::map<Node, Rational> elements;
+
+  for (const auto& [a, countA] : elementsA)
+  {
+    Node aProjection = TupleUtils::getTupleProjection(aIndices, a);
+    aProjection = rewriter->rewrite(aProjection);
+    Assert(aProjection.isConst());
+    for (const auto& [b, countB] : elementsB)
+    {
+      Node bProjection = TupleUtils::getTupleProjection(bIndices, b);
+      bProjection = rewriter->rewrite(bProjection);
+      Assert(bProjection.isConst());
+      if (aProjection == bProjection)
+      {
+        Node element = constructProductTuple(n, a, b);
+        elements[element] = countA * countB;
+      }
+    }
+  }
+
+  Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements);
+  return ret;
+}
+
 Node BagsUtils::evaluateTableProject(TNode n)
 {
   Assert(n.getKind() == TABLE_PROJECT);
@@ -955,6 +1005,24 @@ Node BagsUtils::evaluateTableProject(TNode n)
   return ret;
 }
 
+std::pair<std::vector<uint32_t>, std::vector<uint32_t>>
+BagsUtils::splitTableJoinIndices(Node n)
+{
+  Assert(n.getKind() == kind::TABLE_JOIN && n.hasOperator()
+         && n.getOperator().getKind() == kind::TABLE_JOIN_OP);
+  TableJoinOp op = n.getOperator().getConst<TableJoinOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+  size_t joinSize = indices.size() / 2;
+  std::vector<uint32_t> indices1(joinSize), indices2(joinSize);
+
+  for (size_t i = 0, index = 0; i < joinSize; i += 2, ++index)
+  {
+    indices1[index] = indices[i];
+    indices2[index] = indices[i + 1];
+  }
+  return std::make_pair(indices1, indices2);
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5::internal
index 21de8e959e53dee35d039c9eaf9fbb2318e1f32e..23f21371b75d14a8656405de9ae5dfb84b725baf 100644 (file)
@@ -54,7 +54,7 @@ class BagsUtils
    * evaluate the node n to a constant value.
    * As a precondition, children of n should be constants.
    */
-  static Node evaluate(TNode n);
+  static Node evaluate(Rewriter* rewriter, TNode n);
 
   /**
    * get the elements along with their multiplicities in a given bag
@@ -94,7 +94,14 @@ class BagsUtils
    * @param n has the form (bag.partition r A) where A is a constant bag
    * @return a partition of A based on the equivalence relation r
    */
-  static Node evaluateBagPartition(Rewriter *rewriter, TNode n);
+  static Node evaluateBagPartition(Rewriter* rewriter, TNode n);
+
+  /**
+   * @param n has the form ((_ table.aggr n1 ... n_k) f initial A)
+   * where initial and A are constants
+   * @return the aggregation result.
+   */
+  static Node evaluateTableAggregate(Rewriter* rewriter, TNode n);
 
   /**
    * @param n has the form (bag.filter p A) where A is a constant bag
@@ -117,6 +124,14 @@ class BagsUtils
    */
   static Node evaluateProduct(TNode n);
 
+  /**
+   * @param n of the form ((_ table.join (m_1 n_1 ... m_k n_k) ) A B) where
+   * A, B are constants
+   * @return the evaluation of inner joining tables A B on columns (m_1, n_1,
+   * ..., m_k, n_k)
+   */
+  static Node evaluateJoin(Rewriter* rewriter, TNode n);
+
   /**
    * @param n of the form ((_ table.project i_1 ... i_n) A) where A is a
    * constant
@@ -124,6 +139,14 @@ class BagsUtils
    */
   static Node evaluateTableProject(TNode n);
 
+  /**
+   * @param n has the form ((_ table.join m1 n1 ... mk nk) A B)) where A, B are
+   * tables and m1 n1 ... mk nk are indices
+   * @return the pair <[m1 ... mk], [n1 ... nk]>
+   */
+  static std::pair<std::vector<uint32_t>, std::vector<uint32_t>>
+  splitTableJoinIndices(Node n);
+
  private:
   /**
    * a high order helper function that return a constant bag that is the result
index 18a82bb7b86d266d67c0a2547d42fb65ff7f8dab..1db567d54bb81653d17296d770796bd2b53b563d 100644 (file)
@@ -34,7 +34,7 @@ namespace theory {
 namespace bags {
 
 CardSolver::CardSolver(Env& env, SolverState& s, InferenceManager& im)
-    : EnvObj(env), d_state(s), d_ig(&s, &im), d_im(im), d_bagReduction(env)
+    : EnvObj(env), d_state(s), d_ig(&s, &im), d_im(im)
 {
   d_nm = NodeManager::currentNM();
   d_zero = d_nm->mkConstInt(Rational(0));
@@ -238,7 +238,7 @@ void CardSolver::addChildren(const Node& premise,
 
       Node card = d_nm->mkNode(BAG_CARD, parent);
       std::vector<Node> asserts;
-      Node reduced = d_bagReduction.reduceCardOperator(card, asserts);
+      Node reduced = BagReduction::reduceCardOperator(card, asserts);
       asserts.push_back(card.eqNode(reduced));
       InferInfo inferInfo(&d_im, InferenceId::BAGS_CARD);
       inferInfo.d_premises.push_back(premise);
index c72e8eb98ea48ec4d4f59550bf33c9077a78114e..8802564e6055aa97113f66c88f6e1893871de122 100644 (file)
@@ -119,9 +119,6 @@ class CardSolver : protected EnvObj
   InferenceManager& d_im;
   NodeManager* d_nm;
 
-  /** bag reduction */
-  BagReduction d_bagReduction;
-
   /**
    * A map from bag representatives to sets of bag representatives with the
    * invariant that each key is the disjoint union of each set in the value.
index 2000443a2b292d8166ff9d093f302b3bbff27589..715c3266225457f53d86ca8453ea664b314304f7 100644 (file)
@@ -23,6 +23,7 @@
 #include "theory/bags/bags_utils.h"
 #include "theory/bags/inference_manager.h"
 #include "theory/bags/solver_state.h"
+#include "theory/bags/table_project_op.h"
 #include "theory/datatypes/tuple_utils.h"
 #include "theory/quantifiers/fmf/bounded_integers.h"
 #include "theory/uf/equality_engine.h"
@@ -590,6 +591,9 @@ InferInfo InferenceGenerator::productUp(Node n, Node e1, Node e2)
   Node countA = getMultiplicityTerm(e1, A);
   Node countB = getMultiplicityTerm(e2, B);
 
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countA, d_one));
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countB, d_one));
+
   Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(tuple, skolem);
 
@@ -625,9 +629,91 @@ InferInfo InferenceGenerator::productDown(Node n, Node e)
 
   Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count, d_one));
+
+  Node multiply = d_nm->mkNode(MULT, countA, countB);
+  inferInfo.d_conclusion = count.eqNode(multiply);
+
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::joinUp(Node n, Node e1, Node e2)
+{
+  Assert(n.getKind() == TABLE_JOIN);
+  Node A = n[0];
+  Node B = n[1];
+  Node tuple = BagsUtils::constructProductTuple(n, e1, e2);
+
+  std::vector<Node> aElements = TupleUtils::getTupleElements(e1);
+  std::vector<Node> bElements = TupleUtils::getTupleElements(e2);
+  const std::vector<uint32_t>& indices =
+      n.getOperator().getConst<TableJoinOp>().getIndices();
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_PRODUCT_UP);
+
+  for (size_t i = 0; i < indices.size(); i += 2)
+  {
+    Node x = aElements[indices[i]];
+    Node y = bElements[indices[i + 1]];
+    Node equal = x.eqNode(y);
+    inferInfo.d_premises.push_back(equal);
+  }
+
+  Node countA = getMultiplicityTerm(e1, A);
+  Node countB = getMultiplicityTerm(e2, B);
+
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countA, d_one));
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, countB, d_one));
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  Node count = getMultiplicityTerm(tuple, skolem);
 
   Node multiply = d_nm->mkNode(MULT, countA, countB);
   inferInfo.d_conclusion = count.eqNode(multiply);
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::joinDown(Node n, Node e)
+{
+  Assert(n.getKind() == TABLE_JOIN);
+  Assert(e.getType().isSubtypeOf(n.getType().getBagElementType()));
+
+  Node A = n[0];
+  Node B = n[1];
+
+  TypeNode tupleBType = B.getType().getBagElementType();
+  TypeNode tupleAType = A.getType().getBagElementType();
+  size_t tupleALength = tupleAType.getTupleLength();
+  size_t productTupleLength = n.getType().getBagElementType().getTupleLength();
+
+  std::vector<Node> elements = TupleUtils::getTupleElements(e);
+  Node a = TupleUtils::constructTupleFromElements(
+      tupleAType, elements, 0, tupleALength - 1);
+  Node b = TupleUtils::constructTupleFromElements(
+      tupleBType, elements, tupleALength, productTupleLength - 1);
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_JOIN_DOWN);
+
+  Node countA = getMultiplicityTerm(a, A);
+  Node countB = getMultiplicityTerm(b, B);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  Node count = getMultiplicityTerm(e, skolem);
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count, d_one));
+
+  Node multiply = d_nm->mkNode(MULT, countA, countB);
+  Node multiplicityConstraint = count.eqNode(multiply);
+  const std::vector<uint32_t>& indices =
+      n.getOperator().getConst<TableJoinOp>().getIndices();
+  Node joinConstraints = d_true;
+  for (size_t i = 0; i < indices.size(); i += 2)
+  {
+    Node x = elements[indices[i]];
+    Node y = elements[tupleALength + indices[i + 1]];
+    Node equal = x.eqNode(y);
+    joinConstraints = joinConstraints.andNode(equal);
+  }
+  inferInfo.d_conclusion = joinConstraints.andNode(multiplicityConstraint);
 
   return inferInfo;
 }
index c0143c16c2964652eba50459fab20d6cf7ff94c4..7d92ea4b028c70e263cc344ed2f97de18d39b052 100644 (file)
@@ -306,28 +306,60 @@ class InferenceGenerator
   InferInfo filterUpwards(Node n, Node e);
 
   /**
-   * @param n is a (table.product A B) where A, B are bags of tuples
+   * @param n is a (table.product A B) where A, B are tables
    * @param e1 an element of the form (tuple a1 ... am)
    * @param e2 an element of the form (tuple b1 ... bn)
    * @return  an inference that represents the following
-   * (=
-   *   (bag.count (tuple a1 ... am b1 ... bn) skolem)
-   *   (* (bag.count e1 A) (bag.count e2 B)))
+   * (=> (and (bag.member e1 A) (bag.member e2 B))
+   *     (=
+   *       (bag.count (tuple a1 ... am b1 ... bn) skolem)
+   *       (* (bag.count e1 A) (bag.count e2 B))))
    * where skolem is a variable equals (bag.product A B)
    */
   InferInfo productUp(Node n, Node e1, Node e2);
 
   /**
-   * @param n is a (table.product A B) where A, B are bags of tuples
+   * @param n is a (table.product A B) where A, B are tables
    * @param e an element of the form (tuple a1 ... am b1 ... bn)
    * @return an inference that represents the following
-   * (=
-   *   (bag.count (tuple a1 ... am b1 ... bn) skolem)
-   *   (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B)))
+   * (=> (bag.member e skolem)
+   *   (=
+   *     (bag.count (tuple a1 ... am b1 ... bn) skolem)
+   *     (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B))))
    * where skolem is a variable equals (bag.product A B)
    */
   InferInfo productDown(Node n, Node e);
 
+  /**
+   * @param n is a ((_ table.join m1 n1 ... mk nk) A B) where A, B are tables
+   * @param e1 an element of the form (tuple a1 ... am)
+   * @param e2 an element of the form (tuple b1 ... bn)
+   * @return  an inference that represents the following
+   * (=> (and
+   *       (bag.member e1 A)
+   *       (bag.member e2 B)
+   *       (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk}))
+   *     (=
+   *       (bag.count (tuple a1 ... am b1 ... bn) skolem)
+   *       (* (bag.count e1 A) (bag.count e2 B))))
+   * where skolem is a variable equals ((_ table.join m1 n1 ... mk nk) A B)
+   */
+  InferInfo joinUp(Node n, Node e1, Node e2);
+
+  /**
+   * @param n is a (table.product A B) where A, B are tables
+   * @param e an element of the form (tuple a1 ... am b1 ... bn)
+   * @return an inference that represents the following
+   * (=> (bag.member e skolem)
+   *   (and
+   *     (= a_{m1} b_{n1}) ... (= a_{mk} b_{nk})
+   *     (=
+   *       (bag.count (tuple a1 ... am b1 ... bn) skolem)
+   *       (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B))))
+   * where skolem is a variable equals ((_ table.join m1 n1 ... mk nk) A B)
+   */
+  InferInfo joinDown(Node n, Node e);
+
   /**
    * @param element of type T
    * @param bag of type (bag T)
index 1e875e99880738ab921d0a8f4dad25ac10e570a4..ecde27c625e3e58e94b8ed571e7f6fb11c10b544 100644 (file)
@@ -132,8 +132,32 @@ constant TABLE_PROJECT_OP \
 
 parameterized TABLE_PROJECT TABLE_PROJECT_OP 1 "table projection"
 
+# table.aggregate operator
+constant TABLE_AGGREGATE_OP \
+  class \
+  TableAggregateOp \
+  ::cvc5::internal::TableAggregateOpHashFunction \
+  "theory/bags/table_project_op.h" \
+  "operator for TABLE_AGGREGATE; payload is an instance of the cvc5::internal::TableAggregateOp class"
+
+parameterized TABLE_AGGREGATE TABLE_AGGREGATE_OP 3 "table aggregate"
+
+# table.join operator
+constant TABLE_JOIN_OP \
+  class \
+  TableJoinOp \
+  ::cvc5::internal::TableJoinOpHashFunction \
+  "theory/bags/table_project_op.h" \
+  "operator for TABLE_JOIN; payload is an instance of the cvc5::internal::TableJoinOp class"
+
+parameterized TABLE_JOIN TABLE_JOIN_OP 2 "table join"
+
 typerule TABLE_PRODUCT              ::cvc5::internal::theory::bags::TableProductTypeRule
 typerule TABLE_PROJECT_OP           "SimpleTypeRule<RBuiltinOperator>"
 typerule TABLE_PROJECT              ::cvc5::internal::theory::bags::TableProjectTypeRule
+typerule TABLE_AGGREGATE_OP         "SimpleTypeRule<RBuiltinOperator>"
+typerule TABLE_AGGREGATE            ::cvc5::internal::theory::bags::TableAggregateTypeRule
+typerule TABLE_JOIN_OP              "SimpleTypeRule<RBuiltinOperator>"
+typerule TABLE_JOIN                 ::cvc5::internal::theory::bags::TableJoinTypeRule
 
 endtheory
index 0c634351af1034484c2725ca27b8f528c73891ed..17d1d8f9a71ac68b6eb01bc88f955cb13b9934ea 100644 (file)
@@ -26,6 +26,7 @@ const char* toString(Rewrite r)
   switch (r)
   {
     case Rewrite::NONE: return "NONE";
+    case Rewrite::AGGREGATE_CONST: return "AGGREGATE_CONST";
     case Rewrite::BAG_MAKE_COUNT_NEGATIVE: return "BAG_MAKE_COUNT_NEGATIVE";
     case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT";
     case Rewrite::CARD_BAG_MAKE: return "CARD_BAG_MAKE";
index 461ea8703d3df36296d770c60772b4d835362821..eb42e053d7c315bdd1ca37f6c1482653d989c2e3 100644 (file)
@@ -31,6 +31,7 @@ namespace bags {
 enum class Rewrite : uint32_t
 {
   NONE,  // no rewrite happened
+  AGGREGATE_CONST,
   BAG_MAKE_COUNT_NEGATIVE,
   CARD_DISJOINT,
   CARD_BAG_MAKE,
index 426753d8a190f7eb5aea2cb47ba95f1310d2f86e..72700be9d36dd8278e1b3713d04cb738ad852a32 100644 (file)
@@ -22,4 +22,14 @@ TableProjectOp::TableProjectOp(std::vector<uint32_t> indices)
 {
 }
 
+TableAggregateOp::TableAggregateOp(std::vector<uint32_t> indices)
+    : ProjectOp(std::move(indices))
+{
+}
+
+TableJoinOp::TableJoinOp(std::vector<uint32_t> indices)
+    : ProjectOp(std::move(indices))
+{
+}
+
 }  // namespace cvc5::internal
index a061537e90b2021ed5602e81ca744babfb9d50c1..10c45f915b74e942139f56df3ce803f67832e7d8 100644 (file)
@@ -34,11 +34,41 @@ class TableProjectOp : public ProjectOp
 }; /* class TableProjectOp */
 
 /**
- * Hash function for the TupleProjectOpHashFunction objects.
+ * Hash function for the TableProjectOpHashFunction objects.
  */
 struct TableProjectOpHashFunction : public ProjectOpHashFunction
 {
-}; /* struct TupleProjectOpHashFunction */
+}; /* struct TableProjectOpHashFunction */
+
+class TableAggregateOp : public ProjectOp
+{
+ public:
+  explicit TableAggregateOp(std::vector<uint32_t> indices);
+  TableAggregateOp(const TableAggregateOp& op) = default;
+}; /* class TableAggregateOp */
+
+/**
+ * Hash function for the TableAggregateOpHashFunction objects.
+ */
+struct TableAggregateOpHashFunction : public ProjectOpHashFunction
+{
+}; /* struct TableAggregateOpHashFunction */
+
+
+class TableJoinOp : public ProjectOp
+{
+ public:
+  explicit TableJoinOp(std::vector<uint32_t> indices);
+  TableJoinOp(const TableJoinOp& op) = default;
+}; /* class TableJoinOp */
+
+/**
+ * Hash function for the TableJoinOpHashFunction objects.
+ */
+struct TableJoinOpHashFunction : public ProjectOpHashFunction
+{
+}; /* struct TableJoinOpHashFunction */
+
 
 }  // namespace cvc5::internal
 
index adcf3d468c1b6849cc1c30fa6db10c19cb491e73..3fa491a7c8f976bc3cb05ede91ce1ab744c1bbd5 100644 (file)
@@ -41,8 +41,7 @@ TheoryBags::TheoryBags(Env& env, OutputChannel& out, Valuation valuation)
       d_rewriter(env.getRewriter(), &d_statistics.d_rewrites),
       d_termReg(env, d_state, d_im),
       d_solver(env, d_state, d_im, d_termReg),
-      d_cardSolver(env, d_state, d_im),
-      d_bagReduction(env)
+      d_cardSolver(env, d_state, d_im)
 {
   // use the official theory state and inference manager objects
   d_theoryState = &d_state;
@@ -83,6 +82,8 @@ void TheoryBags::finishInit()
   d_equalityEngine->addFunctionKind(BAG_PARTITION);
   d_equalityEngine->addFunctionKind(TABLE_PRODUCT);
   d_equalityEngine->addFunctionKind(TABLE_PROJECT);
+  d_equalityEngine->addFunctionKind(TABLE_AGGREGATE);
+  d_equalityEngine->addFunctionKind(TABLE_JOIN);
 }
 
 TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
@@ -95,7 +96,7 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
     case kind::BAG_FOLD:
     {
       std::vector<Node> asserts;
-      Node ret = d_bagReduction.reduceFoldOperator(atom, asserts);
+      Node ret = BagReduction::reduceFoldOperator(atom, asserts);
       NodeManager* nm = NodeManager::currentNM();
       Node andNode = nm->mkNode(AND, asserts);
       d_im.lemma(andNode, InferenceId::BAGS_FOLD);
@@ -104,6 +105,12 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
                          << andNode << std::endl;
       return TrustNode::mkTrustRewrite(atom, ret, nullptr);
     }
+    case kind::TABLE_AGGREGATE:
+    {
+      Node ret = BagReduction::reduceAggregateOperator(atom);
+      Trace("bags::ppr") << "reduce(" << atom << ") = " << ret << std::endl;
+      return TrustNode::mkTrustRewrite(atom, ret, nullptr);
+    }
     default: return TrustNode::null();
   }
 }
index 8259a2392f9c0dc3afb8c5e082c8e65d5bc081bb..9c95f991ea2b9a191280d937b9cd2a30a0e1ea32 100644 (file)
@@ -130,9 +130,6 @@ class TheoryBags : public Theory
   /** the main solver for bags */
   CardSolver d_cardSolver;
 
-  /** bag reduction */
-  BagReduction d_bagReduction;
-
   /** The representation of the strategy */
   Strategy d_strat;
 
index e786a6afc68818e0567e417639bfcd9fa1f575d0..5ca0f081538a008b416c379b69ab38f79f0982a0 100644 (file)
@@ -33,6 +33,8 @@ namespace cvc5::internal {
 namespace theory {
 namespace bags {
 
+using namespace datatypes;
+
 TypeNode BinaryOperatorTypeRule::computeType(NodeManager* nodeManager,
                                              TNode n,
                                              bool check)
@@ -530,16 +532,8 @@ TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager,
     throw TypeCheckingExceptionPrivate(n, ss.str());
   }
 
-  std::vector<TypeNode> productTupleTypes;
-  std::vector<TypeNode> tupleATypes = elementAType.getTupleTypes();
-  std::vector<TypeNode> tupleBTypes = elementBType.getTupleTypes();
-
-  productTupleTypes.insert(
-      productTupleTypes.end(), tupleATypes.begin(), tupleATypes.end());
-  productTupleTypes.insert(
-      productTupleTypes.end(), tupleBTypes.begin(), tupleBTypes.end());
-
-  TypeNode retTupleType = nodeManager->mkTupleType(productTupleTypes);
+  TypeNode retTupleType =
+      TupleUtils::concatTupleTypes(elementAType, elementBType);
   TypeNode retType = nodeManager->mkBagType(retTupleType);
   return retType;
 }
@@ -595,7 +589,140 @@ TypeNode TableProjectTypeRule::computeType(NodeManager* nm, TNode n, bool check)
   }
   TypeNode tupleType = bagType.getBagElementType();
   TypeNode retTupleType =
-      datatypes::TupleUtils::getTupleProjectionType(indices, tupleType);
+      TupleUtils::getTupleProjectionType(indices, tupleType);
+  return nm->mkBagType(retTupleType);
+}
+
+TypeNode TableAggregateTypeRule::computeType(NodeManager* nm,
+                                             TNode n,
+                                             bool check)
+{
+  Assert(n.getKind() == kind::TABLE_AGGREGATE && n.hasOperator()
+         && n.getOperator().getKind() == kind::TABLE_AGGREGATE_OP);
+  TableAggregateOp op = n.getOperator().getConst<TableAggregateOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+
+  TypeNode functionType = n[0].getType(check);
+  TypeNode initialValueType = n[1].getType(check);
+  TypeNode bagType = n[2].getType(check);
+
+  if (check)
+  {
+    if (!bagType.isBag())
+    {
+      std::stringstream ss;
+      ss << "TABLE_PROJECT operator expects a table. Found '" << n[2]
+         << "' of type '" << bagType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TypeNode tupleType = bagType.getBagElementType();
+    if (!tupleType.isTuple())
+    {
+      std::stringstream ss;
+      ss << "TABLE_PROJECT operator expects a table. Found '" << n[2]
+         << "' of type '" << bagType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TupleUtils::checkTypeIndices(n, tupleType, indices);
+
+    TypeNode elementType = bagType.getBagElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T T) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    TypeNode rangeType = functionType.getRangeType();
+    if (!(argTypes.size() == 2 && argTypes[0] == elementType
+          && argTypes[1] == rangeType))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T T). "
+         << "Found a function of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    if (rangeType != initialValueType)
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects an initial value of type "
+         << rangeType << ". Found a term of type '" << initialValueType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  return nm->mkBagType(functionType.getRangeType());
+}
+
+TypeNode TableJoinTypeRule::computeType(NodeManager* nm, TNode n, bool check)
+{
+  Assert(n.getKind() == kind::TABLE_JOIN && n.hasOperator()
+         && n.getOperator().getKind() == kind::TABLE_JOIN_OP);
+  TableJoinOp op = n.getOperator().getConst<TableJoinOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+  Node A = n[0];
+  Node B = n[1];
+  TypeNode aType = A.getType();
+  TypeNode bType = B.getType();
+
+  if (check)
+  {
+    if (!(aType.isBag() && bType.isBag()))
+    {
+      std::stringstream ss;
+      ss << "TABLE_JOIN operator expects two tables. Found '" << n[0] << "', '"
+         << n[1] << "' of types '" << aType << "', '" << bType
+         << "' respectively. ";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TypeNode aTupleType = aType.getBagElementType();
+    TypeNode bTupleType = bType.getBagElementType();
+    if (!(aTupleType.isTuple() && bTupleType.isTuple()))
+    {
+      std::stringstream ss;
+      ss << "TABLE_JOIN operator expects two tables. Found '" << n[0] << "', '"
+         << n[1] << "' of types '" << aType << "', '" << bType
+         << "' respectively. ";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    if (indices.size() % 2 != 0)
+    {
+      std::stringstream ss;
+      ss << "TABLE_JOIN operator expects even number of indices. Found "
+         << indices.size() << " in term " << n;
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    auto [aIndices, bIndices] = BagsUtils::splitTableJoinIndices(n);
+    TupleUtils::checkTypeIndices(n, aTupleType, aIndices);
+    TupleUtils::checkTypeIndices(n, bTupleType, bIndices);
+
+    // check the types of columns
+    std::vector<TypeNode> aTypes = aTupleType.getTupleTypes();
+    std::vector<TypeNode> bTypes = bTupleType.getTupleTypes();
+    for (uint32_t i = 0; i < aIndices.size(); i++)
+    {
+      if (aTypes[aIndices[i]] != bTypes[bIndices[i]])
+      {
+        std::stringstream ss;
+        ss << "TABLE_JOIN operator expects column " << aIndices[i]
+           << " in table " << n[0] << " to match column " << bIndices[i]
+           << " in table " << n[1] << ". But their types are "
+           << aTypes[aIndices[i]] << " and " << bTypes[bIndices[i]]
+           << "' respectively. ";
+        throw TypeCheckingExceptionPrivate(n, ss.str());
+      }
+    }
+  }
+  TypeNode aTupleType = aType.getBagElementType();
+  TypeNode bTupleType = bType.getBagElementType();
+  TypeNode retTupleType = TupleUtils::concatTupleTypes(aTupleType, bTupleType);
   return nm->mkBagType(retTupleType);
 }
 
index 04e5bfd04f8b295e8f8917bcb0fc29e3f9339fce..445d627db4a4d2b4ada5db1e3b482da746263bd4 100644 (file)
@@ -160,7 +160,8 @@ struct BagFoldTypeRule
 }; /* struct BagFoldTypeRule */
 
 /**
- * Type rule for (bag.partition r A) to make sure r is a binary operation of type
+ * Type rule for (bag.partition r A) to make sure r is a binary operation of
+ * type
  * (-> T1 T1 Bool), and A is a bag of type (Bag T1)
  */
 struct BagPartitionTypeRule
@@ -188,6 +189,33 @@ struct TableProjectTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct BagFoldTypeRule */
 
+/**
+ * Table aggregate operator is indexed by a list of indices (n_1, ..., n_k).
+ * It ensures that it has 3 arguments:
+ * - A combining function of type (-> (Tuple T_1 ... T_j) T T)
+ * - Initial value of type T
+ * - A table of type (Table T_1 ... T_j) where 0 <= n_1, ..., n_k < j
+ * the returned type is (Bag T).
+ */
+struct TableAggregateTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct TableAggregateTypeRule */
+
+/**
+ * Table join operator is indexed by a list of indices (m_1, m_k, n_1, ...,
+ * n_k). It ensures that it has 2 arguments:
+ * - A table of type (Table X_1 ... X_i)
+ * - A table of type (Table Y_1 ... Y_j)
+ * such that indices has constraints 0 <= m_1, ..., mk, n_1, ..., n_k <=
+ * min(i,j) and types has constraints X_{m_1} = Y_{n_1}, ..., X_{m_k} = Y_{n_k}.
+ * The returned type is (Table X_1 ... X_i Y_1 ... Y_j)
+ */
+struct TableJoinTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct TableJoinTypeRule */
+
 struct BagsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
index 05528a6442ed73d39af0f90ccb9d33b27f9869a1..838e840c3799306e90fefda6540645272696dd8c 100644 (file)
@@ -15,6 +15,8 @@
 
 #include "tuple_utils.h"
 
+#include <sstream>
+
 #include "expr/dtype.h"
 #include "expr/dtype_cons.h"
 
@@ -24,6 +26,41 @@ namespace cvc5::internal {
 namespace theory {
 namespace datatypes {
 
+void TupleUtils::checkTypeIndices(Node n,
+                                  TypeNode tupleType,
+                                  const std::vector<uint32_t> indices)
+{
+  // make sure all indices are less than the size of the tuple
+  DType dType = tupleType.getDType();
+  DTypeConstructor constructor = dType[0];
+  size_t numArgs = constructor.getNumArgs();
+  for (uint32_t index : indices)
+  {
+    std::stringstream ss;
+    if (index >= numArgs)
+    {
+      ss << "Index " << index << " in term " << n << " is > " << (numArgs - 1)
+         << " the maximum value ";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+}
+
+TypeNode TupleUtils::concatTupleTypes(TypeNode tupleType1, TypeNode tupleType2)
+{
+  std::vector<TypeNode> concatTupleTypes;
+  std::vector<TypeNode> tuple1Types = tupleType1.getTupleTypes();
+  std::vector<TypeNode> tuple2Types = tupleType2.getTupleTypes();
+
+  concatTupleTypes.insert(
+      concatTupleTypes.end(), tuple1Types.begin(), tuple1Types.end());
+  concatTupleTypes.insert(
+      concatTupleTypes.end(), tuple2Types.begin(), tuple2Types.end());
+  NodeManager* nm = NodeManager::currentNM();
+  TypeNode ret = nm->mkTupleType(concatTupleTypes);
+  return ret;
+}
+
 Node TupleUtils::nthElementOfTuple(Node tuple, int n_th)
 {
   if (tuple.getKind() == APPLY_CONSTRUCTOR)
index 04112139737ef72098fbb19d399cd6c13413c2d2..9afbd59fe724d74f54216feed8500776b114786d 100644 (file)
@@ -25,6 +25,25 @@ namespace datatypes {
 class TupleUtils
 {
  public:
+  /**
+   *
+   * @param n a node to print in the message if TypeCheckingExceptionPrivate
+   * exception is thrown
+   * @param tupleType the type of the tuple
+   * @param indices a list of indices for projection
+   * @throw an exception if one of the indices in node n is greater than the
+   * expected tuple's length
+   */
+  static void checkTypeIndices(Node n,
+                               TypeNode tupleType,
+                               const std::vector<uint32_t> indices);
+  /**
+   * @param tupleType1 tuple type
+   * @param tupleType2 tuple type
+   * @return the type of concatenation of tupleType1, tupleType2
+   */
+  static TypeNode concatTupleTypes(TypeNode tupleType1, TypeNode tupleType2);
+
   /**
    * @param tuple a node of tuple type
    * @param n_th the index of the element to be extracted, and must satisfy the
index f298c6b852d048cb1cb6b150c47ec26494cd9373..fe4ff0efc1bfbc21fbe7e83be178dd0652d5b5de 100644 (file)
@@ -140,6 +140,8 @@ const char* toString(InferenceId i)
     case InferenceId::BAGS_CARD_EMPTY: return "BAGS_CARD_EMPTY";
     case InferenceId::TABLES_PRODUCT_UP: return "TABLES_PRODUCT_UP";
     case InferenceId::TABLES_PRODUCT_DOWN: return "TABLES_PRODUCT_DOWN";
+    case InferenceId::TABLES_JOIN_UP: return "TABLES_JOIN_UP";
+    case InferenceId::TABLES_JOIN_DOWN: return "TABLES_JOIN_DOWN";
 
     case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT";
     case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA:
index 034eeacc1ca3f5f436b4df0b26a49f6115cedebf..b56168e720056144f722234c0740f634bf2ae8c9 100644 (file)
@@ -206,6 +206,8 @@ enum class InferenceId
   BAGS_CARD_EMPTY,
   TABLES_PRODUCT_UP,
   TABLES_PRODUCT_DOWN,
+  TABLES_JOIN_UP,
+  TABLES_JOIN_DOWN,
   // ---------------------------------- end bags theory
 
   // ---------------------------------- bitvector theory
index 13324b7a4263b505c86400a26538fc79a6d1fd27..100138346b6af48ecdd39c9eaa398964d3c995c8 100644 (file)
@@ -1813,6 +1813,10 @@ set(regress_1_tests
   regress1/bags/proj-issue497.smt2
   regress1/bags/subbag1.smt2
   regress1/bags/subbag2.smt2
+  regress1/bags/table_aggregate1.smt2
+  regress1/bags/table_join1.smt2
+  regress1/bags/table_join2.smt2
+  regress1/bags/table_join3.smt2
   regress1/bags/table_project1.smt2
   regress1/bags/union_disjoint.smt2
   regress1/bags/union_max1.smt2
diff --git a/test/regress/cli/regress1/bags/table_aggregate1.smt2 b/test/regress/cli/regress1/bags/table_aggregate1.smt2
new file mode 100644 (file)
index 0000000..637def3
--- /dev/null
@@ -0,0 +1,31 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+
+(define-fun sumByCategory ((x (Tuple String String Int)) (y (Tuple String Int))) (Tuple String Int)
+  (tuple
+   ((_ tuple.select 0) x)
+   (+ ((_ tuple.select 2) x) ((_ tuple.select 1) y))))
+
+(declare-fun categorySales () (Bag (Tuple String Int)))
+
+;(define-fun categorySales () (Bag (Tuple String Int))
+;  (bag.union_disjoint
+;   (bag (tuple "Hardware" 10) 1)
+;   (bag (tuple "Software" 6) 1)))
+
+(assert
+ (= categorySales
+    ((_ table.aggr 0)
+      sumByCategory
+      (tuple "" 0)
+      (bag.union_disjoint
+       (bag (tuple "Software" "win" 1) 2)
+       (bag (tuple "Software" "mac" 4) 1)
+       (bag (tuple "Hardware" "cpu" 2) 2)
+       (bag (tuple "Hardware" "gpu" 3) 2)))))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/bags/table_join1.smt2 b/test/regress/cli/regress1/bags/table_join1.smt2
new file mode 100644 (file)
index 0000000..349137f
--- /dev/null
@@ -0,0 +1,29 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun Departments () (Table Int String))
+(declare-fun Students () (Table Int String Int))
+(declare-fun DepartmentStudents () (Table Int String Int String Int))
+
+(assert
+ (= Departments
+    (bag.union_disjoint
+     (bag (tuple 1 "Computer") 1)
+     (bag (tuple 2 "Engineering") 1))))
+
+(assert
+ (= Students
+    (bag.union_disjoint
+     (bag (tuple 1 "A" 1) 1)
+     (bag (tuple 2 "B" 1) 1)
+     (bag (tuple 3 "C" 2) 1))))
+
+;(define-fun DepartmentStudents () (Bag (Tuple Int String Int String Int))
+;  (bag.union_disjoint (bag (tuple 1 "Computer" 1 "A" 1) 1)
+;                      (bag (tuple 1 "Computer" 2 "B" 1) 1)
+;                      (bag (tuple 2 "Engineering" 3 "C" 2) 1)))
+
+(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students)))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/bags/table_join2.smt2 b/test/regress/cli/regress1/bags/table_join2.smt2
new file mode 100644 (file)
index 0000000..2c86e85
--- /dev/null
@@ -0,0 +1,28 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun Departments () (Table Int String))
+(declare-fun Students () (Table Int String Int))
+(declare-fun DepartmentStudents () (Table Int String Int String Int))
+
+;(define-fun Departments () (Bag (Tuple Int String))
+;  (bag.union_disjoint
+;   (bag (tuple 1 "Computer") 1)
+;   (bag (tuple 2 "Engineering") 1)))
+
+;(define-fun Students () (Bag (Tuple Int String Int))
+;  (bag.union_disjoint
+;   (bag (tuple 1 "A" 1) 1)
+;   (bag (tuple 2 "B" 1) 1)
+;   (bag (tuple 3 "C" 2) 1)))
+
+(assert
+ (= DepartmentStudents
+    (bag.union_disjoint (bag (tuple 1 "Computer" 1 "A" 1) 1)
+                        (bag (tuple 1 "Computer" 2 "B" 1) 1)
+                        (bag (tuple 2 "Engineering" 3 "C" 2) 1))))
+
+(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students)))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/bags/table_join3.smt2 b/test/regress/cli/regress1/bags/table_join3.smt2
new file mode 100644 (file)
index 0000000..aa416d0
--- /dev/null
@@ -0,0 +1,27 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun Departments () (Table Int String))
+(declare-fun Students () (Table Int String Int))
+(declare-fun DepartmentStudents () (Table Int String Int String Int))
+
+(declare-fun d1 () (Tuple Int String))
+(declare-fun d2 () (Tuple Int String))
+(assert (distinct d1 d2))
+
+(declare-fun s1 () (Tuple Int String Int))
+(declare-fun s2 () (Tuple Int String Int))
+(assert (distinct s1 s2))
+
+(assert
+  (distinct DepartmentStudents (as bag.empty (Table Int String Int String Int))))
+
+(assert (bag.member d1 Departments))
+(assert (bag.member d2 Departments))
+(assert (bag.member s1 Students))
+(assert (bag.member s2 Students))
+
+(assert (= DepartmentStudents ((_ table.join 0 2) Departments Students)))
+
+(check-sat)
index d144f1107a0699b6ad540f5831b09e31af1ecbbc..e6cc722a3cd0a31b1e31f78e2653be54e398e30b 100644 (file)
@@ -56,6 +56,8 @@ class TestTheoryWhiteBagsNormalForm : public TestSmt
     return elements;
   }
 
+  Node rewrite(Node n) { return d_rewriter->postRewrite(n).d_node; }
+
   std::unique_ptr<BagsRewriter> d_rewriter;
 };
 
@@ -64,7 +66,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, empty_bag_normal_form)
   Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType()));
   // empty bags are in normal form
   ASSERT_TRUE(emptybag.isConst());
-  Node n = BagsUtils::evaluate(emptybag);
+  Node n = rewrite(emptybag);
   ASSERT_EQ(emptybag, n);
 }
 
@@ -85,9 +87,9 @@ TEST_F(TestTheoryWhiteBagsNormalForm, mkBag_constant_element)
 
   ASSERT_FALSE(negative.isConst());
   ASSERT_FALSE(zero.isConst());
-  ASSERT_EQ(emptybag, BagsUtils::evaluate(negative));
-  ASSERT_EQ(emptybag, BagsUtils::evaluate(zero));
-  ASSERT_EQ(positive, BagsUtils::evaluate(positive));
+  ASSERT_EQ(emptybag, rewrite(negative));
+  ASSERT_EQ(emptybag, rewrite(zero));
+  ASSERT_EQ(positive, rewrite(positive));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, bag_count)
@@ -116,25 +118,25 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_count)
 
   Node input1 = d_nodeManager->mkNode(BAG_COUNT, x, empty);
   Node output1 = zero;
-  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+  ASSERT_EQ(output1, rewrite(input1));
 
   Node input2 = d_nodeManager->mkNode(BAG_COUNT, x, y_5);
   Node output2 = zero;
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   Node input3 = d_nodeManager->mkNode(BAG_COUNT, x, x_4);
   Node output3 = four;
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   Node unionDisjointXY = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
   Node input4 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointXY);
   Node output4 = four;
-  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+  ASSERT_EQ(output3, rewrite(input3));
 
   Node unionDisjointYZ = d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_5, z_5);
   Node input5 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointYZ);
   Node output5 = zero;
-  ASSERT_EQ(output4, BagsUtils::evaluate(input4));
+  ASSERT_EQ(output4, rewrite(input4));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
@@ -151,7 +153,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
       EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
   Node input1 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, emptybag);
   Node output1 = emptybag;
-  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+  ASSERT_EQ(output1, rewrite(input1));
 
   Node x = d_nodeManager->mkConst(String("x"));
   Node y = d_nodeManager->mkConst(String("y"));
@@ -168,12 +170,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
 
   Node input2 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, x_4);
   Node output2 = x_1;
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
   Node input3 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, normalBag);
   Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
-  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+  ASSERT_EQ(output3, rewrite(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, union_max)
@@ -213,7 +215,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_max)
       d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, BagsUtils::evaluate(input));
+  ASSERT_EQ(output, rewrite(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
@@ -234,12 +236,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
   Node unionDisjointAB = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B);
   // unionDisjointAB is already in a normal form
   ASSERT_TRUE(unionDisjointAB.isConst());
-  ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointAB));
+  ASSERT_EQ(unionDisjointAB, rewrite(unionDisjointAB));
 
   Node unionDisjointBA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, B, A);
   // unionDisjointAB is the normal form of unionDisjointBA
   ASSERT_FALSE(unionDisjointBA.isConst());
-  ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointBA));
+  ASSERT_EQ(unionDisjointAB, rewrite(unionDisjointBA));
 
   Node unionDisjointAB_C =
       d_nodeManager->mkNode(BAG_UNION_DISJOINT, unionDisjointAB, C);
@@ -249,7 +251,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
   // unionDisjointA_BC is the normal form of unionDisjointAB_C
   ASSERT_FALSE(unionDisjointAB_C.isConst());
   ASSERT_TRUE(unionDisjointA_BC.isConst());
-  ASSERT_EQ(unionDisjointA_BC, BagsUtils::evaluate(unionDisjointAB_C));
+  ASSERT_EQ(unionDisjointA_BC, rewrite(unionDisjointAB_C));
 
   Node unionDisjointAA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, A);
   Node AA = d_nodeManager->mkBag(d_nodeManager->stringType(),
@@ -257,7 +259,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
                                  d_nodeManager->mkConstInt(Rational(4)));
   ASSERT_FALSE(unionDisjointAA.isConst());
   ASSERT_TRUE(AA.isConst());
-  ASSERT_EQ(AA, BagsUtils::evaluate(unionDisjointAA));
+  ASSERT_EQ(AA, rewrite(unionDisjointAA));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2)
@@ -297,7 +299,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2)
       d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, BagsUtils::evaluate(input));
+  ASSERT_EQ(output, rewrite(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min)
@@ -332,7 +334,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min)
   Node output = x_3;
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, BagsUtils::evaluate(input));
+  ASSERT_EQ(output, rewrite(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract)
@@ -369,7 +371,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract)
   Node output = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, z_2);
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, BagsUtils::evaluate(input));
+  ASSERT_EQ(output, rewrite(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove)
@@ -406,7 +408,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove)
   Node output = z_2;
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, BagsUtils::evaluate(input));
+  ASSERT_EQ(output, rewrite(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, bag_card)
@@ -429,16 +431,16 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_card)
   Node input1 = d_nodeManager->mkNode(BAG_CARD, empty);
   Node output1 = d_nodeManager->mkConstInt(Rational(0));
 
-  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+  ASSERT_EQ(output1, rewrite(input1));
 
   Node input2 = d_nodeManager->mkNode(BAG_CARD, x_4);
   Node output2 = d_nodeManager->mkConstInt(Rational(4));
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_1);
   Node input3 = d_nodeManager->mkNode(BAG_CARD, union_disjoint);
   Node output3 = d_nodeManager->mkConstInt(Rational(5));
-  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+  ASSERT_EQ(output3, rewrite(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton)
@@ -466,20 +468,20 @@ TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton)
 
   Node input1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, empty);
   Node output1 = falseNode;
-  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+  ASSERT_EQ(output1, rewrite(input1));
 
   Node input2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_1);
   Node output2 = trueNode;
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   Node input3 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_4);
   Node output3 = falseNode;
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
   Node input4 = d_nodeManager->mkNode(BAG_IS_SINGLETON, union_disjoint);
   Node output4 = falseNode;
-  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+  ASSERT_EQ(output3, rewrite(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
@@ -497,7 +499,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
       EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
   Node input1 = d_nodeManager->mkNode(BAG_FROM_SET, emptyset);
   Node output1 = emptybag;
-  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+  ASSERT_EQ(output1, rewrite(input1));
 
   Node x = d_nodeManager->mkConst(String("x"));
   Node y = d_nodeManager->mkConst(String("y"));
@@ -512,13 +514,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
 
   Node input2 = d_nodeManager->mkNode(BAG_FROM_SET, xSingleton);
   Node output2 = x_1;
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   // for normal sets, the first node is the largest, not smallest
   Node normalSet = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
   Node input3 = d_nodeManager->mkNode(BAG_FROM_SET, normalSet);
   Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
-  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+  ASSERT_EQ(output3, rewrite(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
@@ -536,7 +538,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
       EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
   Node input1 = d_nodeManager->mkNode(BAG_TO_SET, emptybag);
   Node output1 = emptyset;
-  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
+  ASSERT_EQ(output1, rewrite(input1));
 
   Node x = d_nodeManager->mkConst(String("x"));
   Node y = d_nodeManager->mkConst(String("y"));
@@ -551,13 +553,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
 
   Node input2 = d_nodeManager->mkNode(BAG_TO_SET, x_4);
   Node output2 = xSingleton;
-  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
+  ASSERT_EQ(output2, rewrite(input2));
 
   // for normal sets, the first node is the largest, not smallest
   Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
   Node input3 = d_nodeManager->mkNode(BAG_TO_SET, normalBag);
   Node output3 = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
-  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
+  ASSERT_EQ(output3, rewrite(input3));
 }
 }  // namespace test
 }  // namespace cvc5::internal