Add table.project evaluator (#8632)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 19 Apr 2022 18:49:14 +0000 (13:49 -0500)
committerGitHub <noreply@github.com>
Tue, 19 Apr 2022 18:49:14 +0000 (18:49 +0000)
This PR

adds evaluator for table.project operator
updates the parser to interpret "Table" as a table with zero columns

23 files changed:
src/CMakeLists.txt
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/Smt2.g
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bags_utils.cpp
src/theory/bags/bags_utils.h
src/theory/bags/kinds
src/theory/bags/table_project_op.cpp [new file with mode: 0644]
src/theory/bags/table_project_op.h [new file with mode: 0644]
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/theory_datatypes_type_rules.cpp
src/theory/datatypes/theory_datatypes_type_rules.h
src/theory/datatypes/tuple_project_op.cpp
src/theory/datatypes/tuple_project_op.h
src/theory/datatypes/tuple_utils.cpp
src/theory/datatypes/tuple_utils.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/bags/table_project1.smt2 [new file with mode: 0644]

index f414db97d4752ab3aaff0270137b7d67906bb2c1..ea10d8fcf947bf44a1c82929605f14b973546f6a 100644 (file)
@@ -571,6 +571,8 @@ libcvc5_add_sources(
   theory/bags/solver_state.h
   theory/bags/strategy.cpp
   theory/bags/strategy.h
+  theory/bags/table_project_op.cpp
+  theory/bags/table_project_op.h
   theory/bags/term_registry.cpp
   theory/bags/term_registry.h
   theory/bags/theory_bags.cpp
index 5de54018cb81d4bcd7bb1906ceebb70f2de82015..d3c28aa06a807cea4eb094140ef5b5e6b2ba216d 100644 (file)
@@ -69,6 +69,7 @@
 #include "smt/model.h"
 #include "smt/smt_mode.h"
 #include "smt/solver_engine.h"
+#include "theory/bags/table_project_op.h"
 #include "theory/datatypes/tuple_project_op.h"
 #include "theory/logic_info.h"
 #include "theory/theory_model.h"
@@ -328,6 +329,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(BAG_FILTER, internal::Kind::BAG_FILTER),
         KIND_ENUM(BAG_FOLD, internal::Kind::BAG_FOLD),
         KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT),
+        KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT),
         /* Strings ---------------------------------------------------------- */
         KIND_ENUM(STRING_CONCAT, internal::Kind::STRING_CONCAT),
         KIND_ENUM(STRING_IN_REGEXP, internal::Kind::STRING_IN_REGEXP),
@@ -643,6 +645,8 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::BAG_FILTER, BAG_FILTER},
         {internal::Kind::BAG_FOLD, BAG_FOLD},
         {internal::Kind::TABLE_PRODUCT, TABLE_PRODUCT},
+        {internal::Kind::TABLE_PROJECT, TABLE_PROJECT},
+        {internal::Kind::TABLE_PROJECT_OP, TABLE_PROJECT},
         /* Strings --------------------------------------------------------- */
         {internal::Kind::STRING_CONCAT, STRING_CONCAT},
         {internal::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
@@ -6215,6 +6219,9 @@ Op Solver::mkOp(Kind kind, const std::vector<uint32_t>& args) const
     case TUPLE_PROJECT:
       res = mkOpHelper(kind, internal::TupleProjectOp(args));
       break;
+    case TABLE_PROJECT:
+      res = mkOpHelper(kind, internal::TableProjectOp(args));
+      break;
     default:
       if (nargs == 0)
       {
index e1ba6740589e573561214de88f674241755330d7..bbc5cfdd8bf65af895a8ee18de504425fcaed07d 100644 (file)
@@ -3695,7 +3695,22 @@ enum Kind : int32_t
    * \endrst
    */
   TABLE_PRODUCT,
-
+  /**
+   * Table projection operator extends tuple projection operator to tables.
+   *
+   * - Arity: ``1``
+   *   - ``1:`` Term of tuple Sort
+   *
+   * - Indices: ``n``
+   *   - ``1..n:`` The table indices to project
+   *
+   * - Create Term of this Kind with:
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * - Create Op of this kind with:
+   *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   */
+  TABLE_PROJECT,
   /* Strings --------------------------------------------------------------- */
 
   /**
index 224da51f44aa6b82d9559cdf16ec739c336ba06e..e66636e22d3f7bcdeb385abdd66075157559a5d1 100644 (file)
@@ -1448,6 +1448,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2]
     cvc5::Op op = SOLVER->mkOp(cvc5::TUPLE_PROJECT, indices);
     expr = SOLVER->mkTerm(op, {expr});
   }
+  | LPAREN_TOK TABLE_PROJECT_TOK term[expr,expr2] RPAREN_TOK
+  {
+    std::vector<uint32_t> indices;
+    cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_PROJECT, indices);
+    expr = SOLVER->mkTerm(op, {expr});
+  }
   | /* an atomic term (a term with no subterms) */
     termAtomic[atomTerm] { expr = atomTerm; }
   ;
@@ -1589,6 +1595,13 @@ identifier[cvc5::ParseOp& p]
         p.d_kind = cvc5::TUPLE_PROJECT;
         p.d_op = SOLVER->mkOp(cvc5::TUPLE_PROJECT, numerals);
       }
+    | TABLE_PROJECT_TOK nonemptyNumeralList[numerals]
+      {
+        // we adopt a special syntax (_ table.project i_1 ... i_n) where
+        // i_1, ..., i_n are numerals
+        p.d_kind = cvc5::TABLE_PROJECT;
+        p.d_op = SOLVER->mkOp(cvc5::TABLE_PROJECT, numerals);
+       }
     | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals]
       {
         cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName);
@@ -2197,6 +2210,7 @@ FORALL_TOK        : 'forall';
 CHAR_TOK : { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_STRINGS) }? 'char';
 TUPLE_CONST_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple';
 TUPLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple.project';
+TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.project';
 FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card';
 
 HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->';
index 41a98c61398c7ccaa8208dcd11eb41b4848bfba7..05a1e19823104f27cfc8480c0a3f677218d332c5 100644 (file)
@@ -258,6 +258,7 @@ void Smt2::addSepOperators() {
 void Smt2::addCoreSymbols()
 {
   defineType("Bool", d_solver->getBooleanSort(), true);
+  defineType("Table", d_solver->mkBagSort(d_solver->mkTupleSort({})), true);
   defineVar("true", d_solver->mkTrue(), true);
   defineVar("false", d_solver->mkFalse(), true);
   addOperator(cvc5::AND, "and");
@@ -1123,7 +1124,7 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector<cvc5::Term>& args)
     Trace("parser") << "applyParseOp: return selector " << ret << std::endl;
     return ret;
   }
-  else if (p.d_kind == cvc5::TUPLE_PROJECT)
+  else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT)
   {
     cvc5::Term ret = d_solver->mkTerm(p.d_op, args);
     Trace("parser") << "applyParseOp: return projection " << ret << std::endl;
index 91a3f21868f665bed02701e6c122957af249b839..7d5dc1cc524594aec125dbdb18527539ae9d89e7 100644 (file)
@@ -42,6 +42,7 @@
 #include "proof/unsat_core.h"
 #include "smt/command.h"
 #include "smt_util/boolean_simplification.h"
+#include "theory/bags/table_project_op.h"
 #include "theory/arrays/theory_arrays_rewriter.h"
 #include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/datatypes/tuple_project_op.h"
@@ -792,6 +793,21 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     return;
   }
+  case kind::TABLE_PROJECT:
+  {
+    TableProjectOp op = n.getOperator().getConst<TableProjectOp>();
+    if (op.getIndices().empty())
+    {
+      // e.g. (table.project A)
+      out << "table.project " << n[0] << ")";
+    }
+    else
+    {
+      // e.g. ((_ table.project 2 4 4) A)
+      out << "(_ table.project" << op << ") " << n[0] << ")";
+    }
+    return;
+  }
   case kind::CONSTRUCTOR_TYPE:
   {
     out << n[n.getNumChildren()-1];
@@ -1168,6 +1184,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_FILTER: return "bag.filter";
   case kind::BAG_FOLD: return "bag.fold";
   case kind::TABLE_PRODUCT: return "table.product";
+  case kind::TABLE_PROJECT: return "table.project";
 
     // fp theory
   case kind::FLOATINGPOINT_FP: return "fp";
index 3c5089943ef672602a94cf8f033661e9e87f3b9c..e719232486101bab5c6c79098b8e53cdc28719f8 100644 (file)
@@ -18,6 +18,7 @@
 #include "expr/dtype_cons.h"
 #include "expr/emptybag.h"
 #include "smt/logic_exception.h"
+#include "table_project_op.h"
 #include "theory/datatypes/tuple_utils.h"
 #include "theory/sets/normal_form.h"
 #include "theory/type_enumerator.h"
@@ -141,6 +142,7 @@ Node BagsUtils::evaluate(TNode n)
     case BAG_FILTER: return evaluateBagFilter(n);
     case BAG_FOLD: return evaluateBagFold(n);
     case TABLE_PRODUCT: return evaluateProduct(n);
+    case TABLE_PROJECT: return evaluateTableProject(n);
     default: break;
   }
   Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
@@ -829,6 +831,36 @@ Node BagsUtils::evaluateProduct(TNode n)
   return ret;
 }
 
+Node BagsUtils::evaluateTableProject(TNode n)
+{
+  Assert(n.getKind() == TABLE_PROJECT);
+  // Examples
+  // --------
+  // - ((_ table.project 1) (bag (tuple true "a") 4)) = (bag (tuple "a") 4)
+  // - (table.project (bag.union_disjoint
+  //                    (bag (tuple "a") 4)
+  //                    (bag (tuple "b") 3))) = (bag tuple 7)
+
+  Node A = n[0];
+
+  std::map<Node, Rational> elementsA = BagsUtils::getBagElements(A);
+
+  std::map<Node, Rational> elements;
+  std::vector<uint32_t> indices =
+      n.getOperator().getConst<TableProjectOp>().getIndices();
+
+  for (const auto& [a, countA] : elementsA)
+  {
+    Node element = TupleUtils::getTupleProjection(indices, a);
+    // multiple elements could be projected to the same tuple.
+    // Zero is the default value for Rational values.
+    elements[element] += countA;
+  }
+
+  Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements);
+  return ret;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5::internal
index 41b8c14d3acce78a7c05caaf8308d985bb5707ec..42e7b0caf066d7b35827a4010c49d3db746dbd7b 100644 (file)
@@ -109,6 +109,13 @@ class BagsUtils
    */
   static Node evaluateProduct(TNode n);
 
+  /**
+   * @param n of the form ((_ table.project i_1 ... i_n) A) where A is a
+   * constant
+   * @return the evaluation of the projection
+   */
+  static Node evaluateTableProject(TNode n);
+
  private:
   /**
    * a high order helper function that return a constant bag that is the result
index 49e1d56247a53e438b9508cf132ae788249c7572..49bca83fbbddab2cb161bc9cb4db0d2ab0b571a6 100644 (file)
@@ -114,9 +114,21 @@ construle BAG_UNION_DISJOINT     ::cvc5::internal::theory::bags::BinaryOperatorT
 construle BAG_MAKE               ::cvc5::internal::theory::bags::BagMakeTypeRule
 
 
-# bag.product operator returns the cross product of two tables
+# table.product operator returns the cross product of two tables
 operator TABLE_PRODUCT             2 "table cross product"
 
+# table.project operator extends datatypes tuple_project operator to a bag of tuples
+constant TABLE_PROJECT_OP \
+  class \
+  TableProjectOp \
+  ::cvc5::internal::TableProjectOpHashFunction \
+  "theory/bags/table_project_op.h" \
+  "operator for TABLE_PROJECT; payload is an instance of the cvc5::internal::TableProjectOp class"
+
+parameterized TABLE_PROJECT TABLE_PROJECT_OP 1 "table projection"
+
 typerule TABLE_PRODUCT              ::cvc5::internal::theory::bags::TableProductTypeRule
+typerule TABLE_PROJECT_OP           "SimpleTypeRule<RBuiltinOperator>"
+typerule TABLE_PROJECT              ::cvc5::internal::theory::bags::TableProjectTypeRule
 
 endtheory
diff --git a/src/theory/bags/table_project_op.cpp b/src/theory/bags/table_project_op.cpp
new file mode 100644 (file)
index 0000000..426753d
--- /dev/null
@@ -0,0 +1,25 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2022 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * A class for TableProjectOp operator.
+ */
+
+#include "table_project_op.h"
+
+namespace cvc5::internal {
+
+TableProjectOp::TableProjectOp(std::vector<uint32_t> indices)
+    : ProjectOp(std::move(indices))
+{
+}
+
+}  // namespace cvc5::internal
diff --git a/src/theory/bags/table_project_op.h b/src/theory/bags/table_project_op.h
new file mode 100644 (file)
index 0000000..a061537
--- /dev/null
@@ -0,0 +1,45 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2022 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * A class for TableProjectOp operator.
+ */
+
+#include "cvc5_public.h"
+
+#ifndef CVC5__TABLE_PROJECT_OP_H
+#define CVC5__TABLE_PROJECT_OP_H
+
+#include "theory/datatypes/tuple_project_op.h"
+
+namespace cvc5::internal {
+
+/**
+ * The class is an operator for kind project used to project elements in a
+ * table. It stores the indices of projected elements
+ */
+class TableProjectOp : public ProjectOp
+{
+ public:
+  explicit TableProjectOp(std::vector<uint32_t> indices);
+  TableProjectOp(const TableProjectOp& op) = default;
+}; /* class TableProjectOp */
+
+/**
+ * Hash function for the TupleProjectOpHashFunction objects.
+ */
+struct TableProjectOpHashFunction : public ProjectOpHashFunction
+{
+}; /* struct TupleProjectOpHashFunction */
+
+}  // namespace cvc5::internal
+
+#endif /* CVC5__TABLE_PROJECT_OP_H */
index 4307dcbe345217aee0dc2fb2838ac3b193177a47..92ea5eccaa6cf0c3d8557247a92356d5b8d2a480 100644 (file)
@@ -80,6 +80,8 @@ void TheoryBags::finishInit()
   d_equalityEngine->addFunctionKind(BAG_CARD);
   d_equalityEngine->addFunctionKind(BAG_FROM_SET);
   d_equalityEngine->addFunctionKind(BAG_TO_SET);
+  d_equalityEngine->addFunctionKind(TABLE_PRODUCT);
+  d_equalityEngine->addFunctionKind(TABLE_PROJECT);
 }
 
 TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
@@ -453,6 +455,7 @@ void TheoryBags::preRegisterTerm(TNode n)
     case BAG_FROM_SET:
     case BAG_TO_SET:
     case BAG_IS_SINGLETON:
+    case TABLE_PROJECT:
     {
       std::stringstream ss;
       ss << "Term of kind " << n.getKind() << " is not supported yet";
index fd47f006bb9a12d5781dfcf94b500b9d8734c1ea..ef2a5a35002bb59f1fd0918a45e4aee2bfc47cea 100644 (file)
 #include <sstream>
 
 #include "base/check.h"
+#include "expr/dtype.h"
+#include "expr/dtype_cons.h"
 #include "expr/emptybag.h"
+#include "table_project_op.h"
 #include "theory/bags/bag_make_op.h"
 #include "theory/bags/bags_utils.h"
+#include "theory/datatypes/tuple_project_op.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "util/cardinality.h"
 #include "util/rational.h"
 
@@ -494,6 +499,61 @@ TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager,
   return retType;
 }
 
+TypeNode TableProjectTypeRule::computeType(NodeManager* nm, TNode n, bool check)
+{
+  Assert(n.getKind() == kind::TABLE_PROJECT && n.hasOperator()
+         && n.getOperator().getKind() == kind::TABLE_PROJECT_OP);
+  TableProjectOp op = n.getOperator().getConst<TableProjectOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+  TypeNode bagType = n[0].getType(check);
+  if (check)
+  {
+    if (n.getNumChildren() != 1)
+    {
+      std::stringstream ss;
+      ss << "operands in term " << n << " are " << n.getNumChildren()
+         << ", but TABLE_PROJECT expects 1 operand.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    if (!bagType.isBag())
+    {
+      std::stringstream ss;
+      ss << "TABLE_PROJECT operator expects a table. Found '" << n[0]
+         << "' of type '" << bagType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TypeNode tupleType = bagType.getBagElementType();
+    if (!tupleType.isTuple())
+    {
+      std::stringstream ss;
+      ss << "TABLE_PROJECT operator expects a table. Found '" << n[0]
+         << "' of type '" << bagType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    // make sure all indices are less than the length of the tuple type
+    DType dType = tupleType.getDType();
+    DTypeConstructor constructor = dType[0];
+    size_t numArgs = constructor.getNumArgs();
+    for (uint32_t index : indices)
+    {
+      std::stringstream ss;
+      if (index >= numArgs)
+      {
+        ss << "Index " << index << " in term " << n << " is >= " << numArgs
+           << " which is the number of columns in " << n[0] << ".";
+        throw TypeCheckingExceptionPrivate(n, ss.str());
+      }
+    }
+  }
+  TypeNode tupleType = bagType.getBagElementType();
+  TypeNode retTupleType =
+      datatypes::TupleUtils::getTupleProjectionType(indices, tupleType);
+  return nm->mkBagType(retTupleType);
+}
+
 Cardinality BagsProperties::computeCardinality(TypeNode type)
 {
   return Cardinality::INTEGERS;
index a80523415e035331e5696b4af4eb1a82a5dc454d..54329b405de0ff2e1069aaa95462a1601c91be95 100644 (file)
@@ -168,6 +168,17 @@ struct TableProductTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct BagFoldTypeRule */
 
+/**
+ * Table project is indexed by a list of indices (n_1, ..., n_m). It ensures
+ * that the argument is a bag of tuples whose arity k is greater than each n_i
+ * for i = 1, ..., m. If the argument is of type (Bag (Tuple T_1 ... T_k)), then
+ * the returned type is (Bag (Tuple T_{n_1} ... T_{n_m})).
+ */
+struct TableProjectTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFoldTypeRule */
+
 struct BagsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
index fccac800d7e6a6a27f3552f652f5afc0c05941e8..307e7cb619659ae1c47e9d38e44a227ee3f4b37a 100644 (file)
@@ -26,6 +26,7 @@
 #include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/datatypes/theory_datatypes_utils.h"
 #include "theory/datatypes/tuple_project_op.h"
+#include "tuple_utils.h"
 #include "util/rational.h"
 #include "util/uninterpreted_sort_value.h"
 
@@ -165,29 +166,11 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
     // where each i_j is less than the length of t
 
     Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl;
+
     TupleProjectOp op = in.getOperator().getConst<TupleProjectOp>();
     std::vector<uint32_t> indices = op.getIndices();
     Node tuple = in[0];
-    std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
-    std::vector<TypeNode> types;
-    std::vector<Node> elements;
-    for (uint32_t index : indices)
-    {
-      TypeNode type = tupleTypes[index];
-      types.push_back(type);
-    }
-    TypeNode projectType = nm->mkTupleType(types);
-    const DType& dt = projectType.getDType();
-    elements.push_back(dt[0].getConstructor());
-    const DType& tupleDType = tuple.getType().getDType();
-    const DTypeConstructor& constructor = tupleDType[0];
-    for (uint32_t index : indices)
-    {
-      Node selector = constructor[index].getSelector();
-      Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
-      elements.push_back(element);
-    }
-    Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
+    Node ret = TupleUtils::getTupleProjection(indices, tuple);
 
     Trace("dt-rewrite-project")
         << "Rewrite project: " << in << " ... " << ret << std::endl;
index 6f27fdce806f60dffce32021768532e6f0b95c42..f4450a57d827b97c09b9b5be21597de7e70b0776 100644 (file)
@@ -24,6 +24,7 @@
 #include "expr/type_matcher.h"
 #include "theory/datatypes/theory_datatypes_utils.h"
 #include "theory/datatypes/tuple_project_op.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "util/rational.h"
 
 namespace cvc5::internal {
@@ -560,14 +561,7 @@ TypeNode TupleProjectTypeRule::computeType(NodeManager* nm, TNode n, bool check)
     }
   }
   TypeNode tupleType = n[0].getType(check);
-  std::vector<TypeNode> types;
-  DType dType = tupleType.getDType();
-  DTypeConstructor constructor = dType[0];
-  for (uint32_t index : indices)
-  {
-    types.push_back(constructor.getArgType(index));
-  }
-  return nm->mkTupleType(types);
+  return TupleUtils::getTupleProjectionType(indices, tupleType);
 }
 
 TypeNode CodatatypeBoundVariableTypeRule::computeType(NodeManager* nodeManager,
index f5d77d1ac73520b168e0740bd79fc6e3d6b1dbff..0696edc4eb1e34cd215426b389e1f1405c5ac549 100644 (file)
@@ -161,7 +161,7 @@ class MatchBindCaseTypeRule
 
 /**
  * Tuple project is indexed by a list of indices (n_1, ..., n_m). It ensures
- * that the argument is a tuple whose arity k is greater that each n_i for
+ * that the argument is a tuple whose arity k is greater than each n_i for
  * i = 1, ..., m. If the argument is of type (Tuple T_1 ... T_k), then the
  * returned type is (Tuple T_{n_1} ... T_{n_m}).
  */
index 335a031675ef0f653be5e18d47eaf8473dfe6c6a..28027528f1ea7d10afd443846458dffe41ef0596 100644 (file)
@@ -10,7 +10,7 @@
  * directory for licensing information.
  * ****************************************************************************
  *
- * A class for TupleProjectOp operator.
+ * A class for ProjectOp operator.
  */
 
 #include "tuple_project_op.h"
@@ -21,7 +21,7 @@
 
 namespace cvc5::internal {
 
-std::ostream& operator<<(std::ostream& out, const TupleProjectOp& op)
+std::ostream& operator<<(std::ostream& out, const ProjectOp& op)
 {
   for (const uint32_t& index : op.getIndices())
   {
@@ -30,7 +30,7 @@ std::ostream& operator<<(std::ostream& out, const TupleProjectOp& op)
   return out;
 }
 
-size_t TupleProjectOpHashFunction::operator()(const TupleProjectOp& op) const
+size_t ProjectOpHashFunction::operator()(const ProjectOp& op) const
 {
   // we expect most tuples to have length < 10.
   // Therefore we can implement a simple hash function
@@ -42,16 +42,21 @@ size_t TupleProjectOpHashFunction::operator()(const TupleProjectOp& op) const
   return hash;
 }
 
-TupleProjectOp::TupleProjectOp(std::vector<uint32_t> indices)
+ProjectOp::ProjectOp(std::vector<uint32_t> indices)
     : d_indices(std::move(indices))
 {
 }
 
-const std::vector<uint32_t>& TupleProjectOp::getIndices() const { return d_indices; }
+const std::vector<uint32_t>& ProjectOp::getIndices() const { return d_indices; }
 
-bool TupleProjectOp::operator==(const TupleProjectOp& op) const
+bool ProjectOp::operator==(const ProjectOp& op) const
 {
   return d_indices == op.d_indices;
 }
 
+TupleProjectOp::TupleProjectOp(std::vector<uint32_t> indices)
+    : ProjectOp(std::move(indices))
+{
+}
+
 }  // namespace cvc5::internal
index 269c38b171d182bdd589f4cb3bf7d332bc81d404..7d1b46ff958ba98253c954a8db54565d96e3d58b 100644 (file)
@@ -26,32 +26,49 @@ namespace cvc5::internal {
 class TypeNode;
 
 /**
- * The class is an operator for kind project used to project elements in a tuple
- * It stores the indices of projected elements
+ * base class for TupleProjectOp, TupleProjectOp
  */
-class TupleProjectOp
+class ProjectOp
 {
  public:
-  explicit TupleProjectOp(std::vector<uint32_t> indices);
-  TupleProjectOp(const TupleProjectOp& op) = default;
+  explicit ProjectOp(std::vector<uint32_t> indices);
+  ProjectOp(const ProjectOp& op) = default;
 
   /** return the indices of the projection */
   const std::vector<uint32_t>& getIndices() const;
 
-  bool operator==(const TupleProjectOp& op) const;
+  bool operator==(const ProjectOp& op) const;
 
  private:
   std::vector<uint32_t> d_indices;
-}; /* class TupleProjectOp */
+}; /* class ProjectOp */
+
+std::ostream& operator<<(std::ostream& out, const ProjectOp& op);
+
+/**
+ * Hash function for the ProjectOpHashFunction objects.
+ */
+struct ProjectOpHashFunction
+{
+  size_t operator()(const ProjectOp& op) const;
+}; /* struct ProjectOpHashFunction */
 
-std::ostream& operator<<(std::ostream& out, const TupleProjectOp& op);
+/**
+ * The class is an operator for kind project used to project elements in a
+ * table. It stores the indices of projected elements
+ */
+class TupleProjectOp : public ProjectOp
+{
+ public:
+  explicit TupleProjectOp(std::vector<uint32_t> indices);
+  TupleProjectOp(const TupleProjectOp& op) = default;
+}; /* class TupleProjectOp */
 
 /**
  * Hash function for the TupleProjectOpHashFunction objects.
  */
-struct TupleProjectOpHashFunction
+struct TupleProjectOpHashFunction : public ProjectOpHashFunction
 {
-  size_t operator()(const TupleProjectOp& op) const;
 }; /* struct TupleProjectOpHashFunction */
 
 }  // namespace cvc5::internal
index 87114d9b141eba8009215a2c5b7bf8f014c1f177..05528a6442ed73d39af0f90ccb9d33b27f9869a1 100644 (file)
@@ -36,6 +36,48 @@ Node TupleUtils::nthElementOfTuple(Node tuple, int n_th)
       APPLY_SELECTOR, dt[0].getSelectorInternal(tn, n_th), tuple);
 }
 
+Node TupleUtils::getTupleProjection(const std::vector<uint32_t>& indices,
+                                    Node tuple)
+{
+  std::vector<TypeNode> tupleTypes = tuple.getType().getTupleTypes();
+  std::vector<TypeNode> types;
+  std::vector<Node> elements;
+  for (uint32_t index : indices)
+  {
+    TypeNode type = tupleTypes[index];
+    types.push_back(type);
+  }
+  NodeManager* nm = NodeManager::currentNM();
+  TypeNode projectType = nm->mkTupleType(types);
+  const DType& dt = projectType.getDType();
+  elements.push_back(dt[0].getConstructor());
+  const DType& tupleDType = tuple.getType().getDType();
+  const DTypeConstructor& constructor = tupleDType[0];
+  for (uint32_t index : indices)
+  {
+    Node selector = constructor[index].getSelector();
+    Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple);
+    elements.push_back(element);
+  }
+  Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements);
+  return ret;
+}
+
+TypeNode TupleUtils::getTupleProjectionType(
+    const std::vector<uint32_t>& indices, TypeNode tupleType)
+{
+  std::vector<TypeNode> types;
+  DType dType = tupleType.getDType();
+  DTypeConstructor constructor = dType[0];
+  for (uint32_t index : indices)
+  {
+    types.push_back(constructor.getArgType(index));
+  }
+  NodeManager* nm = NodeManager::currentNM();
+  TypeNode retTupleType = nm->mkTupleType(types);
+  return retTupleType;
+}
+
 std::vector<Node> TupleUtils::getTupleElements(Node tuple)
 {
   Assert(tuple.getType().isTuple());
index f6651c50f67fe702bfcd90d8eb61878d1c6e9f02..04112139737ef72098fbb19d399cd6c13413c2d2 100644 (file)
@@ -33,6 +33,22 @@ class TupleUtils
    */
   static Node nthElementOfTuple(Node tuple, int n_th);
 
+  /**
+   * @param indices a list of indices for projected elements
+   * @param tuple a node of tuple type
+   * @return the projection of the tuple with the specified indices
+   */
+  static Node getTupleProjection(const std::vector<uint32_t>& indices,
+                                 Node tuple);
+
+  /**
+   * @param indices a list of indices for projected elements
+   * @param tupleType the type of the original tuple
+   * @return the type of the projected tuple
+   */
+  static TypeNode getTupleProjectionType(const std::vector<uint32_t>& indices,
+                                         TypeNode tupleType);
+
   /**
    * @param tuple a tuple node of the form (tuple a_1 ... a_n)
    * @return the vector [a_1, ... a_n]
index 871fe3b66781acb552a7f6264361a4e94d995d02..dd4039cf01b4209799b0f5709deaa58b9b63ba66 100644 (file)
@@ -1798,6 +1798,7 @@ set(regress_1_tests
   regress1/bags/proj-issue497.smt2
   regress1/bags/subbag1.smt2
   regress1/bags/subbag2.smt2
+  regress1/bags/table_project1.smt2
   regress1/bags/union_disjoint.smt2
   regress1/bags/union_max1.smt2
   regress1/bags/union_max2.smt2
diff --git a/test/regress/cli/regress1/bags/table_project1.smt2 b/test/regress/cli/regress1/bags/table_project1.smt2
new file mode 100644 (file)
index 0000000..882cf48
--- /dev/null
@@ -0,0 +1,21 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+
+(declare-fun A () (Table String Int String Bool))
+(declare-fun B () (Table Int Bool String String))
+(declare-fun C () (Table String String))
+(declare-fun D () Table)
+
+(assert
+ (= A
+    (bag.union_disjoint
+     (bag (tuple "x" 0 "y" false) 5)
+     (bag (tuple "x" 1 "z" true) 10))))
+
+; (bag.union_disjoint (bag (tuple 0 false "x" "y") 5) (bag (tuple 1 true "x" "z") 10)))
+(assert (= B ((_ table.project 1 3 0 2) A)))
+; (bag (tuple "x" "x") 15)
+(assert (= C ((_ table.project 0 0) A)))
+; (bag tuple 15)
+(assert (= D (table.project A)))
+(check-sat)