Add table.product operator (#8020)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 3 Feb 2022 15:55:16 +0000 (09:55 -0600)
committerGitHub <noreply@github.com>
Thu, 3 Feb 2022 15:55:16 +0000 (15:55 +0000)
30 files changed:
src/CMakeLists.txt
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_solver.cpp
src/theory/bags/bag_solver.h
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/bags_utils.cpp
src/theory/bags/bags_utils.h
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/kinds
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
src/theory/datatypes/tuple_utils.cpp [new file with mode: 0644]
src/theory/datatypes/tuple_utils.h [new file with mode: 0644]
src/theory/inference_id.cpp
src/theory/inference_id.h
src/theory/sets/rels_utils.cpp [new file with mode: 0644]
src/theory/sets/rels_utils.h
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rewriter.cpp
test/regress/CMakeLists.txt
test/regress/regress1/bags/product1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/product2.smt2 [new file with mode: 0644]
test/regress/regress1/bags/product3.smt2 [new file with mode: 0644]

index 226f4632d48b9beb5a439f1c7c43ace31f825a07..87ba2bb94cfa6dcdb73fbe4a64ddcf000f25a618 100644 (file)
@@ -661,6 +661,8 @@ libcvc5_add_sources(
   theory/datatypes/theory_datatypes_utils.h
   theory/datatypes/tuple_project_op.cpp
   theory/datatypes/tuple_project_op.h
+  theory/datatypes/tuple_utils.cpp
+  theory/datatypes/tuple_utils.h
   theory/datatypes/type_enumerator.cpp
   theory/datatypes/type_enumerator.h
   theory/decision_manager.cpp
@@ -997,6 +999,7 @@ libcvc5_add_sources(
   theory/sets/inference_manager.cpp
   theory/sets/inference_manager.h
   theory/sets/normal_form.h
+  theory/sets/rels_utils.cpp
   theory/sets/rels_utils.h
   theory/sets/singleton_op.cpp
   theory/sets/singleton_op.h
index d46b8a971650bfcd1cac89c26413c58c547e8aeb..54174aec4c7738276c6da7096c7c6c1a5e0e733f 100644 (file)
@@ -314,6 +314,7 @@ const static std::unordered_map<Kind, cvc5::Kind> s_kinds{
     {BAG_MAP, cvc5::Kind::BAG_MAP},
     {BAG_FILTER, cvc5::Kind::BAG_FILTER},
     {BAG_FOLD, cvc5::Kind::BAG_FOLD},
+    {TABLE_PRODUCT, cvc5::Kind::TABLE_PRODUCT},
     /* Strings ------------------------------------------------------------- */
     {STRING_CONCAT, cvc5::Kind::STRING_CONCAT},
     {STRING_IN_REGEXP, cvc5::Kind::STRING_IN_REGEXP},
@@ -627,6 +628,7 @@ const static std::unordered_map<cvc5::Kind, Kind, cvc5::kind::KindHashFunction>
         {cvc5::Kind::BAG_MAP, BAG_MAP},
         {cvc5::Kind::BAG_FILTER, BAG_FILTER},
         {cvc5::Kind::BAG_FOLD, BAG_FOLD},
+        {cvc5::Kind::TABLE_PRODUCT, TABLE_PRODUCT},
         /* Strings --------------------------------------------------------- */
         {cvc5::Kind::STRING_CONCAT, STRING_CONCAT},
         {cvc5::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
index 1609fb22146c220d3e3fe7219e51fb69b323c346..112b53eb7124e82ae9a69b232a00bafe4aea7fa5 100644 (file)
@@ -2572,6 +2572,17 @@ enum Kind : int32_t
    *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
    */
   BAG_FOLD,
+  /**
+   * Table cross product.
+   *
+   * Parameters:
+   *   - 1..2: Terms of bag sort
+   *
+   * Create with:
+   *   - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2) const`
+   *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
+   */
+  TABLE_PRODUCT,
 
   /* Strings --------------------------------------------------------------- */
 
index 1c3ea84df25b5ba4a5c2b2bce44bd3838ab92af7..3352bde1ba065dbec770db29abf950cb16eeca09 100644 (file)
@@ -629,6 +629,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::BAG_MAP, "bag.map");
     addOperator(api::BAG_FILTER, "bag.filter");
     addOperator(api::BAG_FOLD, "bag.fold");
+    addOperator(api::TABLE_PRODUCT, "table.product");
   }
   if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) {
     defineType("String", d_solver->getStringSort(), true, true);
index dd74f00719dd8a112ea325ce0076dbce7337c2cd..420c176f71a26a02842d0f19afbebcae5911279e 100644 (file)
@@ -1125,6 +1125,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_MAP: return "bag.map";
   case kind::BAG_FILTER: return "bag.filter";
   case kind::BAG_FOLD: return "bag.fold";
+  case kind::TABLE_PRODUCT: return "table.product";
 
     // fp theory
   case kind::FLOATINGPOINT_FP: return "fp";
index ed4b501f3564c270028acffadf286a8797d802b5..dbd2bbc2911ce26196bbc3aefe5ca7e39bdea55c 100644 (file)
@@ -78,6 +78,7 @@ void BagSolver::checkBasicOperations()
         case kind::BAG_DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
         case kind::BAG_FILTER: checkFilter(n); break;
         case kind::BAG_MAP: checkMap(n); break;
+        case kind::TABLE_PRODUCT: checkProduct(n); break;
         default: break;
       }
       it++;
@@ -303,6 +304,29 @@ void BagSolver::checkFilter(Node n)
   }
 }
 
+void BagSolver::checkProduct(Node n)
+{
+  Assert(n.getKind() == TABLE_PRODUCT);
+  const set<Node>& elementsA = d_state.getElements(n[0]);
+  const set<Node>& elementsB = d_state.getElements(n[1]);
+  for (const Node& e1 : elementsA)
+  {
+    for (const Node& e2 : elementsB)
+    {
+      InferInfo i = d_ig.productUp(
+          n, d_state.getRepresentative(e1), d_state.getRepresentative(e2));
+      d_im.lemmaTheoryInference(&i);
+    }
+  }
+
+  std::set<Node> elements = d_state.getElements(n);
+  for (const Node& e : elements)
+  {
+    InferInfo i = d_ig.productDown(n, d_state.getRepresentative(e));
+    d_im.lemmaTheoryInference(&i);
+  }
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index fca72b22e8046e9265ad237293e0d7afd66d11cb..eb578aafd7b527f7f06b0975cd355a1634b3914f 100644 (file)
@@ -98,6 +98,8 @@ class BagSolver : protected EnvObj
   void checkMap(Node n);
   /** apply inference rules for filter operator */
   void checkFilter(Node n);
+  /** apply inference rules for product operator */
+  void checkProduct(Node n);
 
   /** The solver state object */
   SolverState& d_state;
index 396e33557fdc88ba07bf964a1808e4b24eceb7b7..031910cdd2a4c2a5367fb818d29c5b131db6a33a 100644 (file)
@@ -92,6 +92,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
       case BAG_MAP: response = postRewriteMap(n); break;
       case BAG_FILTER: response = postRewriteFilter(n); break;
       case BAG_FOLD: response = postRewriteFold(n); break;
+      case TABLE_PRODUCT: response = postRewriteProduct(n); break;
       default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
     }
   }
@@ -654,6 +655,20 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
 }
+
+BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
+{
+  Assert(n.getKind() == TABLE_PRODUCT);
+  TypeNode tableType = n.getType();
+  Node empty = d_nm->mkConst(EmptyBag(tableType));
+  if (n[0].getKind() == BAG_EMPTY || n[1].getKind() == BAG_EMPTY)
+  {
+    return BagsRewriteResponse(empty, Rewrite::PRODUCT_EMPTY);
+  }
+
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 3e5b69a1c62d527598c98d2fda8d3563711ce96d..f05766c5344010b1895eea49e59f26f766a39632 100644 (file)
@@ -247,6 +247,15 @@ class BagsRewriter : public TheoryRewriter
    *  where f: T1 -> T2 -> T2
    */
   BagsRewriteResponse postRewriteFold(const TNode& n) const;
+  /**
+   *  rewrites for n include:
+   *  - (bag.product A (as bag.empty T2)) = (as bag.empty T)
+   *  - (bag.product (as bag.empty T2)) = (f t ... (f t (f t x))) n times, where n > 0
+   *  - (bag.fold f t (bag.union_disjoint A B)) =
+   *       (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
+   *  where f: T1 -> T2 -> T2
+   */
+  BagsRewriteResponse postRewriteProduct(const TNode& n)const;
 
  private:
   /** Reference to the rewriter statistics. */
index 39987ce9d5dd9fe4bdba2820d5f215a1e823c91f..6514d8d3fed4ec94d6d132a0ffca37bc8b0b811a 100644 (file)
  */
 #include "bags_utils.h"
 
+#include "expr/dtype.h"
+#include "expr/dtype_cons.h"
 #include "expr/emptybag.h"
 #include "smt/logic_exception.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "theory/sets/normal_form.h"
 #include "theory/type_enumerator.h"
 #include "util/rational.h"
 
 using namespace cvc5::kind;
+using namespace cvc5::theory::datatypes;
 
 namespace cvc5 {
 namespace theory {
@@ -136,6 +140,7 @@ Node BagsUtils::evaluate(TNode n)
     case BAG_MAP: return evaluateBagMap(n);
     case BAG_FILTER: return evaluateBagFilter(n);
     case BAG_FOLD: return evaluateBagFold(n);
+    case TABLE_PRODUCT: return evaluateProduct(n);
     default: break;
   }
   Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
@@ -778,6 +783,52 @@ Node BagsUtils::evaluateBagFold(TNode n)
   return ret;
 }
 
+Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2)
+{
+  Assert(n.getKind() == TABLE_PRODUCT);
+  Node A = n[0];
+  Node B = n[1];
+  TypeNode typeA = A.getType().getBagElementType();
+  TypeNode typeB = B.getType().getBagElementType();
+  Assert(e1.getType().isSubtypeOf(typeA));
+  Assert(e2.getType().isSubtypeOf(typeB));
+
+  TypeNode productTupleType = n.getType().getBagElementType();
+  Node tuple = TupleUtils::concatTuples(productTupleType, e1, e2);
+  return tuple;
+}
+
+Node BagsUtils::evaluateProduct(TNode n)
+{
+  Assert(n.getKind() == TABLE_PRODUCT);
+
+  // Examples
+  // --------
+  //
+  // - (table.product (bag (tuple "a") 4) (bag (tuple true) 5)) =
+  //     (bag (tuple "a" true) 20
+
+  Node A = n[0];
+  Node B = n[1];
+
+  std::map<Node, Rational> elementsA = BagsUtils::getBagElements(A);
+  std::map<Node, Rational> elementsB = BagsUtils::getBagElements(B);
+
+  std::map<Node, Rational> elements;
+
+  for (const auto& [a, countA] : elementsA)
+  {
+    for (const auto& [b, countB] : elementsB)
+    {
+      Node element = constructProductTuple(n, a, b);
+      elements[element] = countA * countB;
+    }
+  }
+
+  Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements);
+  return ret;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 61473a023a99f601384b74cd63002bb9f6a6c4af..3b6311ded60b5c4d90d3fe5127edc61596e9e08e 100644 (file)
@@ -17,8 +17,8 @@
 
 #include "cvc5_private.h"
 
-#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H
-#define CVC5__THEORY__BAGS__NORMAL_FORM_H
+#ifndef CVC5__THEORY__BAGS__UTILS_H
+#define CVC5__THEORY__BAGS__UTILS_H
 
 namespace cvc5 {
 namespace theory {
@@ -94,6 +94,21 @@ class BagsUtils
    */
   static Node evaluateBagFilter(TNode n);
 
+  /**
+   * @param n of the form (table.product A B) where A , B of types (Bag T1),
+   * (Bag T2) respectively.
+   * @param e1 a tuple of type T1 of the form (tuple a1 ... an)
+   * @param e2 a tuple of type T2 of the form (tuple b1 ... bn)
+   * @return  (tuple a1 ... an b1 ... bn)
+   */
+  static Node constructProductTuple(TNode n, TNode e1, TNode e2);
+
+  /**
+   * @param n of the form (table.product A B) where A, B are constants
+   * @return the evaluation of the cross product of A B
+   */
+  static Node evaluateProduct(TNode n);
+
  private:
   /**
    * a high order helper function that return a constant bag that is the result
@@ -220,4 +235,4 @@ class BagsUtils
 }  // namespace theory
 }  // namespace cvc5
 
-#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */
+#endif /* CVC5__THEORY__BAGS__UTILS_H */
index 3cc2936fbcbfeb6e94d1681da9a43657acd648bd..aa5bf74d8cccefd984b6353b5d32f5dbf550d11e 100644 (file)
 
 #include "expr/attribute.h"
 #include "expr/bound_var_manager.h"
+#include "expr/dtype_cons.h"
 #include "expr/emptybag.h"
 #include "expr/skolem_manager.h"
+#include "theory/bags/bags_utils.h"
 #include "theory/bags/inference_manager.h"
 #include "theory/bags/solver_state.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "theory/quantifiers/fmf/bounded_integers.h"
 #include "theory/uf/equality_engine.h"
 #include "util/rational.h"
 
 using namespace cvc5::kind;
+using namespace cvc5::theory::datatypes;
 
 namespace cvc5 {
 namespace theory {
@@ -563,6 +567,60 @@ InferInfo InferenceGenerator::filterUpwards(Node n, Node e)
   return inferInfo;
 }
 
+InferInfo InferenceGenerator::productUp(Node n, Node e1, Node e2)
+{
+  Assert(n.getKind() == TABLE_PRODUCT);
+  Node A = n[0];
+  Node B = n[1];
+  Node tuple = BagsUtils::constructProductTuple(n, e1, e2);
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_PRODUCT_UP);
+
+  Node countA = getMultiplicityTerm(e1, A);
+  Node countB = getMultiplicityTerm(e2, B);
+
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(tuple, skolem);
+
+  Node multiply = d_nm->mkNode(MULT, countA, countB);
+  inferInfo.d_conclusion = count.eqNode(multiply);
+
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::productDown(Node n, Node e)
+{
+  Assert(n.getKind() == TABLE_PRODUCT);
+  Assert(e.getType().isSubtypeOf(n.getType().getBagElementType()));
+
+  Node A = n[0];
+  Node B = n[1];
+
+  TypeNode tupleBType = B.getType().getBagElementType();
+  TypeNode tupleAType = A.getType().getBagElementType();
+  size_t tupleALength = tupleAType.getTupleLength();
+  size_t productTupleLength = n.getType().getBagElementType().getTupleLength();
+
+  std::vector<Node> elements = TupleUtils::getTupleElements(e);
+  Node a = TupleUtils::constructTupleFromElements(
+      tupleAType, elements, 0, tupleALength - 1);
+  Node b = TupleUtils::constructTupleFromElements(
+      tupleBType, elements, tupleALength, productTupleLength - 1);
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_PRODUCT_DOWN);
+
+  Node countA = getMultiplicityTerm(a, A);
+  Node countB = getMultiplicityTerm(b, B);
+
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
+
+  Node multiply = d_nm->mkNode(MULT, countA, countB);
+  inferInfo.d_conclusion = count.eqNode(multiply);
+
+  return inferInfo;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 3d74dbaa23148b97c72cde5816d55697ef721c15..ed6122356fa7f4d77e9066e3ffe7cae499113df0 100644 (file)
@@ -271,7 +271,7 @@ class InferenceGenerator
    *   (bag.member e skolem)
    *   (and
    *     (p e)
-   *     (= (bag.count e skolem) (bag.count A)))
+   *     (= (bag.count e skolem) (bag.count A)))
    * where skolem is a variable equals (bag.filter p A)
    */
   InferInfo filterDownwards(Node n, Node e);
@@ -290,6 +290,29 @@ class InferenceGenerator
    */
   InferInfo filterUpwards(Node n, Node e);
 
+  /**
+   * @param n is a (table.product A B) where A, B are bags of tuples
+   * @param e1 an element of the form (tuple a1 ... am)
+   * @param e2 an element of the form (tuple b1 ... bn)
+   * @return  an inference that represents the following
+   * (=
+   *   (bag.count (tuple a1 ... am b1 ... bn) skolem)
+   *   (* (bag.count e1 A) (bag.count e2 B)))
+   * where skolem is a variable equals (bag.product A B)
+   */
+  InferInfo productUp(Node n, Node e1, Node e2);
+
+  /**
+   * @param n is a (table.product A B) where A, B are bags of tuples
+   * @param e an element of the form (tuple a1 ... am b1 ... bn)
+   * @return an inference that represents the following
+   * (=
+   *   (bag.count (tuple a1 ... am b1 ... bn) skolem)
+   *   (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B)))
+   * where skolem is a variable equals (bag.product A B)
+   */
+  InferInfo productDown(Node n, Node e);
+
   /**
    * @param element of type T
    * @param bag of type (bag T)
index 7d995dd7bd0d64c40c84ae3ff0a87deaf323f241..345b71e9b4d6819621cfc752755cda385ba776b2 100644 (file)
@@ -113,4 +113,10 @@ typerule BAG_FOLD                ::cvc5::theory::bags::BagFoldTypeRule
 construle BAG_UNION_DISJOINT     ::cvc5::theory::bags::BinaryOperatorTypeRule
 construle BAG_MAKE               ::cvc5::theory::bags::BagMakeTypeRule
 
+
+# bag.product operator returns the cross product of two tables
+operator TABLE_PRODUCT             2 "table cross product"
+
+typerule TABLE_PRODUCT              ::cvc5::theory::bags::TableProductTypeRule
+
 endtheory
index 9bd0c3a86ca7afe27786448d10a23b4ff2128b29..576f1245cbcfbbbaa9e52fd83858d79f04f09a2e 100644 (file)
@@ -56,6 +56,7 @@ const char* toString(Rewrite r)
     case Rewrite::MAP_BAG_MAKE: return "MAP_BAG_MAKE";
     case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT";
     case Rewrite::MEMBER: return "MEMBER";
+    case Rewrite::PRODUCT_EMPTY: return "PRODUCT_EMPTY";
     case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION";
     case Rewrite::REMOVE_MIN: return "REMOVE_MIN";
     case Rewrite::REMOVE_RETURN_LEFT: return "REMOVE_RETURN_LEFT";
index e1ef38c4b14e9e695282cc38a8931460f3f1727c..e7f2113f9d9ac1586a7313d89815863d75ab638c 100644 (file)
@@ -60,6 +60,7 @@ enum class Rewrite : uint32_t
   MAP_BAG_MAKE,
   MAP_UNION_DISJOINT,
   MEMBER,
+  PRODUCT_EMPTY,
   REMOVE_FROM_UNION,
   REMOVE_MIN,
   REMOVE_RETURN_LEFT,
index 689b0e208cd2644dfdb15c4a5ccdd059c07af520..b0c79fb1d07bd05e0d427daead177024e99729b8 100644 (file)
@@ -449,6 +449,51 @@ TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager,
   return retType;
 }
 
+TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager,
+                                           TNode n,
+                                           bool check)
+{
+  Assert(n.getKind() == kind::TABLE_PRODUCT);
+  Node A = n[0];
+  Node B = n[1];
+  TypeNode typeA = n[0].getType(check);
+  TypeNode typeB = n[1].getType(check);
+
+  if (check && !(typeA.isBag() && typeB.isBag()))
+  {
+    std::stringstream ss;
+    ss << "Operator " << n.getKind() << " expects two bags. "
+       << "Found two terms of types '" << typeA << "' and '" << typeB
+       << "' respectively.";
+    throw TypeCheckingExceptionPrivate(n, ss.str());
+  }
+
+  TypeNode elementAType = typeA.getBagElementType();
+  TypeNode elementBType = typeB.getBagElementType();
+
+  if (check && !(elementAType.isTuple() && elementBType.isTuple()))
+  {
+    std::stringstream ss;
+    ss << "Operator " << n.getKind() << " expects two tables (bags of tuples). "
+       << "Found two terms of types '" << typeA << "' and '" << typeB
+       << "' respectively.";
+    throw TypeCheckingExceptionPrivate(n, ss.str());
+  }
+
+  std::vector<TypeNode> productTupleTypes;
+  std::vector<TypeNode> tupleATypes = elementAType.getTupleTypes();
+  std::vector<TypeNode> tupleBTypes = elementBType.getTupleTypes();
+
+  productTupleTypes.insert(
+      productTupleTypes.end(), tupleATypes.begin(), tupleATypes.end());
+  productTupleTypes.insert(
+      productTupleTypes.end(), tupleBTypes.begin(), tupleBTypes.end());
+
+  TypeNode retTupleType = nodeManager->mkTupleType(productTupleTypes);
+  TypeNode retType = nodeManager->mkBagType(retTupleType);
+  return retType;
+}
+
 Cardinality BagsProperties::computeCardinality(TypeNode type)
 {
   return Cardinality::INTEGERS;
index 76c179a62fb9dd9a4062c8ee659f146efe9b82f6..8673f7296896da0d81a32cfec1f7b2486c82ef0c 100644 (file)
@@ -159,6 +159,15 @@ struct BagFoldTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct BagFoldTypeRule */
 
+/**
+ * Type rule for (table.product A B) to make sure A,B are bags of tuples,
+ * and get the type of the cross product
+ */
+struct TableProductTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFoldTypeRule */
+
 struct BagsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
diff --git a/src/theory/datatypes/tuple_utils.cpp b/src/theory/datatypes/tuple_utils.cpp
new file mode 100644 (file)
index 0000000..d691b38
--- /dev/null
@@ -0,0 +1,123 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds, Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utility functions for data types.
+ */
+
+#include "tuple_utils.h"
+
+#include "expr/dtype.h"
+#include "expr/dtype_cons.h"
+
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace theory {
+namespace datatypes {
+
+Node TupleUtils::nthElementOfTuple(Node tuple, int n_th)
+{
+  if (tuple.getKind() == APPLY_CONSTRUCTOR)
+  {
+    return tuple[n_th];
+  }
+  TypeNode tn = tuple.getType();
+  const DType& dt = tn.getDType();
+  return NodeManager::currentNM()->mkNode(
+      APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, n_th), tuple);
+}
+
+std::vector<Node> TupleUtils::getTupleElements(Node tuple)
+{
+  Assert(tuple.getType().isTuple());
+  size_t tupleLength = tuple.getType().getTupleLength();
+  std::vector<Node> elements;
+  for (size_t i = 0; i < tupleLength; i++)
+  {
+    elements.push_back(TupleUtils::nthElementOfTuple(tuple, i));
+  }
+  return elements;
+}
+
+std::vector<Node> TupleUtils::getTupleElements(Node tuple1, Node tuple2)
+{
+  std::vector<Node> elements;
+  std::vector<Node> elementsA = getTupleElements(tuple1);
+  size_t tuple1Length = tuple1.getType().getTupleLength();
+  for (size_t i = 0; i < tuple1Length; i++)
+  {
+    elements.push_back(TupleUtils::nthElementOfTuple(tuple1, i));
+  }
+
+  size_t tuple2Length = tuple2.getType().getTupleLength();
+  for (size_t i = 0; i < tuple2Length; i++)
+  {
+    elements.push_back(TupleUtils::nthElementOfTuple(tuple2, i));
+  }
+  return elements;
+}
+
+Node TupleUtils::constructTupleFromElements(TypeNode tupleType,
+                                            const std::vector<Node>& elements,
+                                            size_t start,
+                                            size_t end)
+{
+  std::vector<Node> tupleElements;
+  // add the constructor first
+  Node constructor = tupleType.getDType()[0].getConstructor();
+  tupleElements.push_back(constructor);
+  // add the elements of the tuple
+  for (size_t i = start; i <= end; i++)
+  {
+    tupleElements.push_back(elements[i]);
+  }
+  NodeManager* nm = NodeManager::currentNM();
+  Node tuple = nm->mkNode(APPLY_CONSTRUCTOR, tupleElements);
+  return tuple;
+}
+
+Node TupleUtils::concatTuples(TypeNode tupleType, Node tuple1, Node tuple2)
+{
+  std::vector<Node> tupleElements;
+  // add the constructor first
+  Node constructor = tupleType.getDType()[0].getConstructor();
+  tupleElements.push_back(constructor);
+
+  // add the flattened concatenation of the two tuples e1, e2
+  std::vector<Node> elements = getTupleElements(tuple1, tuple2);
+  tupleElements.insert(tupleElements.end(), elements.begin(), elements.end());
+
+  // construct the returned tuple
+  NodeManager* nm = NodeManager::currentNM();
+  Node tuple = nm->mkNode(APPLY_CONSTRUCTOR, tupleElements);
+  return tuple;
+}
+
+Node TupleUtils::reverseTuple(Node tuple)
+{
+  Assert(tuple.getType().isTuple());
+  std::vector<Node> elements;
+  std::vector<TypeNode> tuple_types = tuple.getType().getTupleTypes();
+  std::reverse(tuple_types.begin(), tuple_types.end());
+  TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types);
+  const DType& dt = tn.getDType();
+  elements.push_back(dt[0].getConstructor());
+  for (int i = tuple_types.size() - 1; i >= 0; --i)
+  {
+    elements.push_back(nthElementOfTuple(tuple, i));
+  }
+  return NodeManager::currentNM()->mkNode(APPLY_CONSTRUCTOR, elements);
+}
+
+}  // namespace datatypes
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/datatypes/tuple_utils.h b/src/theory/datatypes/tuple_utils.h
new file mode 100644 (file)
index 0000000..595052c
--- /dev/null
@@ -0,0 +1,83 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds, Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utility functions for data types.
+ */
+
+#ifndef CVC5__THEORY__TUPLE__UTILS_H
+#define CVC5__THEORY__TUPLE__UTILS_H
+
+#include "expr/node.h"
+
+namespace cvc5 {
+namespace theory {
+namespace datatypes {
+
+class TupleUtils
+{
+ public:
+  /**
+   * @param tuple a node of tuple type
+   * @param n_th the index of the element to be extracted, and must satisfy the
+   * constraint 0 <= n_th < length of tuple.
+   * @return tuple element at index n_th
+   */
+  static Node nthElementOfTuple(Node tuple, int n_th);
+
+  /**
+   * @param tuple a tuple node of the form (tuple a_1 ... a_n)
+   * @return the vector [a_1, ... a_n]
+   */
+  static std::vector<Node> getTupleElements(Node tuple);
+
+  /**
+   * @param tuple1 a tuple node of the form (tuple a_1 ... a_n)
+   * @param tuple2 a tuple node of the form (tuple b_1 ... b_n)
+   * @return the vector [a_1, ... a_n, b_1, ... b_n]
+   */
+  static std::vector<Node> getTupleElements(Node tuple1, Node tuple2);
+
+  /**
+   * construct a tuple from a list of elements
+   * @param tupleType the type of the returned tuple
+   * @param elements the list of nodes
+   * @param start the index of the first element
+   * @param end the index of the last element
+   * @pre the elements from start to end should match the tuple type
+   * @return a tuple of constructed from elements from start to end
+   */
+  static Node constructTupleFromElements(TypeNode tupleType,
+                                         const std::vector<Node>& elements,
+                                         size_t start,
+                                         size_t end);
+
+  /**
+   * construct a flattened tuple from two tuples
+   * @param tupleType the type of the returned tuple
+   * @param tuple1 a tuple node of the form (tuple a_1 ... a_n)
+   * @param tuple2 a tuple node of the form (tuple b_1 ... b_n)
+   * @pre the elements of tuple1, tuple2 should match the tuple type
+   * @return  (tuple a1 ... an b1 ... bn)
+   */
+  static Node concatTuples(TypeNode tupleType, Node tuple1, Node tuple2);
+
+  /**
+   * @param tuple a tuple node of the form (tuple e_1 ... e_n)
+   * @return the reverse of the argument, i.e., (tuple e_n ... e_1)
+   */
+  static Node reverseTuple(Node tuple);
+};
+}  // namespace datatypes
+}  // namespace theory
+}  // namespace cvc5
+
+#endif /* CVC5__THEORY__TUPLE__UTILS_H */
index 240d6e29374f3997e6752d689c646546d443f44b..791819f381b94697b1484f710887aac3efdd25e7 100644 (file)
@@ -124,6 +124,8 @@ const char* toString(InferenceId i)
     case InferenceId::BAGS_FILTER_UP: return "BAGS_FILTER_UP";
     case InferenceId::BAGS_FOLD: return "BAGS_FOLD";
     case InferenceId::BAGS_CARD: return "BAGS_CARD";
+    case InferenceId::TABLES_PRODUCT_UP: return "TABLES_PRODUCT_UP";
+    case InferenceId::TABLES_PRODUCT_DOWN: return "TABLES_PRODUCT_DOWN";
 
     case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT";
     case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA:
index 2fb3ae003b380e207aff2a7f4e3d41ee4fd6d9fd..4301e0d165db1a01e694aa7f475a048f9436bbc9 100644 (file)
@@ -186,6 +186,8 @@ enum class InferenceId
   BAGS_FILTER_UP,
   BAGS_FOLD,
   BAGS_CARD,
+  TABLES_PRODUCT_UP,
+  TABLES_PRODUCT_DOWN,
   // ---------------------------------- end bags theory
 
   // ---------------------------------- bitvector theory
diff --git a/src/theory/sets/rels_utils.cpp b/src/theory/sets/rels_utils.cpp
new file mode 100644 (file)
index 0000000..fdd9e63
--- /dev/null
@@ -0,0 +1,81 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utility functions for relations.
+ */
+
+#include "rels_utils.h"
+
+#include "expr/dtype.h"
+#include "expr/dtype_cons.h"
+#include "theory/datatypes/tuple_utils.h"
+
+using namespace cvc5::theory::datatypes;
+
+namespace cvc5 {
+namespace theory {
+namespace sets {
+
+std::set<Node> RelsUtils::computeTC(const std::set<Node>& members, Node rel)
+{
+  std::set<Node>::iterator mem_it = members.begin();
+  std::map<Node, int> ele_num_map;
+  std::set<Node> tc_rel_mem;
+
+  while (mem_it != members.end())
+  {
+    Node fst = TupleUtils::nthElementOfTuple(*mem_it, 0);
+    Node snd = TupleUtils::nthElementOfTuple(*mem_it, 1);
+    std::set<Node> traversed;
+    traversed.insert(fst);
+    computeTC(rel, members, fst, snd, traversed, tc_rel_mem);
+    mem_it++;
+  }
+  return tc_rel_mem;
+}
+
+void RelsUtils::computeTC(Node rel,
+                          const std::set<Node>& members,
+                          Node a,
+                          Node b,
+                          std::set<Node>& traversed,
+                          std::set<Node>& transitiveClosureMembers)
+{
+  transitiveClosureMembers.insert(constructPair(rel, a, b));
+  if (traversed.find(b) != traversed.end())
+  {
+    return;
+  }
+  traversed.insert(b);
+  std::set<Node>::iterator mem_it = members.begin();
+  while (mem_it != members.end())
+  {
+    Node new_fst = TupleUtils::nthElementOfTuple(*mem_it, 0);
+    Node new_snd = TupleUtils::nthElementOfTuple(*mem_it, 1);
+    if (b == new_fst)
+    {
+      computeTC(rel, members, a, new_snd, traversed, transitiveClosureMembers);
+    }
+    mem_it++;
+  }
+}
+
+Node RelsUtils::constructPair(Node rel, Node a, Node b)
+{
+  const DType& dt = rel.getType().getSetElementType().getDType();
+  return NodeManager::currentNM()->mkNode(
+      kind::APPLY_CONSTRUCTOR, dt[0].getConstructor(), a, b);
+}
+
+}  // namespace sets
+}  // namespace theory
+}  // namespace cvc5
index 46eeecd58cad99fb387590c91f25b2f4ca972b06..c070ad1dafbee775083868273909e00573dfb0e9 100644 (file)
  * directory for licensing information.
  * ****************************************************************************
  *
- * Extension to Sets theory.
+ * Utility functions for relations.
  */
 
 #ifndef SRC_THEORY_SETS_RELS_UTILS_H_
 #define SRC_THEORY_SETS_RELS_UTILS_H_
 
-#include "expr/dtype.h"
-#include "expr/dtype_cons.h"
 #include "expr/node.h"
 
 namespace cvc5 {
 namespace theory {
 namespace sets {
 
-class RelsUtils {
+class RelsUtils
+{
+ public:
+  /**
+   * compute the transitive closure of a binary relation
+   * @param members constant nodes of type (Tuple E E) that are known to in the
+   * relation rel
+   * @param rel a binary relation of type (Set (Tuple E E))
+   * @pre all members need to be constants
+   * @return the transitive closure of the relation
+   */
+  static std::set<Node> computeTC(const std::set<Node>& members, Node rel);
 
-public:
+  /**
+   * add all pairs (a, c) to the transitive closures where c is reachable from b
+   * in the transitive relation in a depth first search manner.
+   * @param rel a binary relation of type (Set (Tuple E E))
+   * @param members constant nodes of type (Tuple E E) that are known to be in
+   * the relation rel
+   * @param a a node of type E where (a,b) is an element in the transitive
+   * closure
+   * @param b a node of type E where (a,b) is an element in the transitive
+   * closure
+   * @param traversed the set of members that have been visited so far
+   * @param transitiveClosureMembers members of the transitive closure computed
+   * so far
+   */
+  static void computeTC(Node rel,
+                        const std::set<Node>& members,
+                        Node a,
+                        Node b,
+                        std::set<Node>& traversed,
+                        std::set<Node>& transitiveClosureMembers);
 
-  // Assumption: the input rel_mem contains all constant pairs
-  static std::set< Node > computeTC( std::set< Node > rel_mem, Node rel ) {
-    std::set< Node >::iterator mem_it = rel_mem.begin();
-    std::map< Node, int > ele_num_map;
-    std::set< Node > tc_rel_mem;
-       
-    while( mem_it != rel_mem.end() ) {
-      Node fst = nthElementOfTuple( *mem_it, 0 );
-      Node snd = nthElementOfTuple( *mem_it, 1 );
-      std::set< Node > traversed;
-      traversed.insert(fst);
-      computeTC(rel, rel_mem, fst, snd, traversed, tc_rel_mem);      
-      mem_it++;             
-    }
-    return tc_rel_mem;
-  }
-  
-  static void computeTC( Node rel, std::set< Node >& rel_mem, Node fst, 
-                         Node snd, std::set< Node >& traversed, std::set< Node >& tc_rel_mem ) {    
-    tc_rel_mem.insert(constructPair(rel, fst, snd));
-    if( traversed.find(snd) == traversed.end() ) {
-      traversed.insert(snd);
-    } else {
-      return;
-    }
-
-    std::set< Node >::iterator mem_it = rel_mem.begin();
-    while( mem_it != rel_mem.end() ) {
-      Node new_fst = nthElementOfTuple( *mem_it, 0 );
-      Node new_snd = nthElementOfTuple( *mem_it, 1 );
-      if( snd == new_fst ) {
-        computeTC(rel, rel_mem, fst, new_snd, traversed, tc_rel_mem);
-      }
-      mem_it++; 
-    }  
-  }
-  static Node nthElementOfTuple( Node tuple, int n_th ) {    
-    if( tuple.getKind() == kind::APPLY_CONSTRUCTOR ) {
-      return tuple[n_th];
-    }
-    TypeNode tn = tuple.getType();
-    const DType& dt = tn.getDType();
-    return NodeManager::currentNM()->mkNode(
-        kind::APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, n_th), tuple);
-  } 
-  
-  static Node reverseTuple( Node tuple ) {
-    Assert(tuple.getType().isTuple());
-    std::vector<Node> elements;
-    std::vector<TypeNode> tuple_types = tuple.getType().getTupleTypes();
-    std::reverse( tuple_types.begin(), tuple_types.end() );
-    TypeNode tn = NodeManager::currentNM()->mkTupleType( tuple_types );
-    const DType& dt = tn.getDType();
-    elements.push_back(dt[0].getConstructor());
-    for(int i = tuple_types.size() - 1; i >= 0; --i) {
-      elements.push_back( nthElementOfTuple(tuple, i) );
-    }
-    return NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, elements );
-  }
-  static Node constructPair(Node rel, Node a, Node b) {
-    const DType& dt = rel.getType().getSetElementType().getDType();
-    return NodeManager::currentNM()->mkNode(
-        kind::APPLY_CONSTRUCTOR, dt[0].getConstructor(), a, b);
-  }     
-    
+  /**
+   * construct a pair from two elements
+   * @param rel a node of type (Set (Tuple E E))
+   * @param a a node of type E
+   * @param b a node of type E
+   * @return  a tuple (tuple a b)
+   */
+  static Node constructPair(Node rel, Node a, Node b);
 };
 }  // namespace sets
 }  // namespace theory
index 49f8f053adf5d3efdc9bcec501fdd2d86307edc1..d6a52b76e4d0582ffc75df7f8b6649342d3d71a8 100644 (file)
 
 #include "theory/sets/theory_sets_rels.h"
 
+#include "expr/dtype.h"
+#include "expr/dtype_cons.h"
 #include "expr/skolem_manager.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "theory/sets/theory_sets.h"
 #include "theory/sets/theory_sets_private.h"
 #include "util/rational.h"
 
 using namespace std;
 using namespace cvc5::kind;
+using namespace cvc5::theory::datatypes;
 
 namespace cvc5 {
 namespace theory {
@@ -268,7 +272,7 @@ void TheorySetsRels::check(Theory::Effort level)
           std::vector<TypeNode> tupleTypes = erType.getTupleTypes();
           for (unsigned i = 0, tlen = erType.getTupleLength(); i < tlen; i++)
           {
-            Node element = RelsUtils::nthElementOfTuple(eqc_node, i);
+            Node element = TupleUtils::nthElementOfTuple(eqc_node, i);
             if (!element.isConst())
             {
               makeSharedTerm(element, tupleTypes[i]);
@@ -306,7 +310,7 @@ void TheorySetsRels::check(Theory::Effort level)
     unsigned int min_card = join_image_term[1].getConst<Rational>().getNumerator().getUnsignedInt();
 
     while( mem_rep_it != (*rel_mem_it).second.end() ) {
-      Node fst_mem_rep = RelsUtils::nthElementOfTuple( *mem_rep_it, 0 );
+      Node fst_mem_rep = TupleUtils::nthElementOfTuple( *mem_rep_it, 0 );
 
       if( hasChecked.find( fst_mem_rep ) != hasChecked.end() ) {
         ++mem_rep_it;
@@ -333,12 +337,14 @@ void TheorySetsRels::check(Theory::Effort level)
       std::vector< Node >::iterator mem_rep_exp_it_snd = (*rel_mem_exp_it).second.begin();
 
       while( mem_rep_exp_it_snd != (*rel_mem_exp_it).second.end() ) {
-        Node fst_element_snd_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 0 );
+        Node fst_element_snd_mem =
+            TupleUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 0 );
 
         if( areEqual( fst_mem_rep,  fst_element_snd_mem ) ) {
           bool notExist = true;
           std::vector< Node >::iterator existing_mem_it = existing_members.begin();
-          Node snd_element_snd_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 1 );
+          Node snd_element_snd_mem =
+              TupleUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 1 );
 
           while( existing_mem_it != existing_members.end() ) {
             if( areEqual( (*existing_mem_it), snd_element_snd_mem ) ) {
@@ -410,7 +416,7 @@ void TheorySetsRels::check(Theory::Effort level)
     Node reason = exp;
     Node conclusion = d_trueNode;
     std::vector< Node > distinct_skolems;
-    Node fst_mem_element = RelsUtils::nthElementOfTuple( exp[0], 0 );
+    Node fst_mem_element = TupleUtils::nthElementOfTuple( exp[0], 0 );
 
     if( exp[1] != join_image_term ) {
       reason =
@@ -451,8 +457,8 @@ void TheorySetsRels::check(Theory::Effort level)
       d_rel_nodes.insert( iden_term );
     }
     Node reason = exp;
-    Node fst_mem = RelsUtils::nthElementOfTuple( exp[0], 0 );
-    Node snd_mem = RelsUtils::nthElementOfTuple( exp[0], 1 );
+    Node fst_mem = TupleUtils::nthElementOfTuple( exp[0], 0 );
+    Node snd_mem = TupleUtils::nthElementOfTuple( exp[0], 1 );
     const DType& dt = iden_term[0].getType().getSetElementType().getDType();
     Node fact = nm->mkNode(
         SET_MEMBER,
@@ -489,7 +495,7 @@ void TheorySetsRels::check(Theory::Effort level)
 
     while( mem_rep_exp_it != (*rel_mem_exp_it).second.end() ) {
       Node reason = *mem_rep_exp_it;
-      Node fst_exp_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it)[0], 0 );
+      Node fst_exp_mem = TupleUtils::nthElementOfTuple( (*mem_rep_exp_it)[0], 0 );
       Node new_mem = RelsUtils::constructPair( iden_term, fst_exp_mem, fst_exp_mem );
 
       if( (*mem_rep_exp_it)[1] != iden_term_rel ) {
@@ -548,8 +554,8 @@ void TheorySetsRels::check(Theory::Effort level)
 
     // add mem_rep to d_tcrRep_tcGraph
     TC_IT tc_it = d_tcr_tcGraph.find( tc_rel );
-    Node mem_rep_fst = getRepresentative( RelsUtils::nthElementOfTuple( mem_rep, 0 ) );
-    Node mem_rep_snd = getRepresentative( RelsUtils::nthElementOfTuple( mem_rep, 1 ) );
+    Node mem_rep_fst = getRepresentative(TupleUtils::nthElementOfTuple( mem_rep, 0 ) );
+    Node mem_rep_snd = getRepresentative(TupleUtils::nthElementOfTuple( mem_rep, 1 ) );
     Node mem_rep_tup = RelsUtils::constructPair( tc_rel, mem_rep_fst, mem_rep_snd );
 
     if( tc_it != d_tcr_tcGraph.end() ) {
@@ -580,8 +586,8 @@ void TheorySetsRels::check(Theory::Effort level)
       exp_map[mem_rep_tup] = exp;
       d_tcr_tcGraph_exps[tc_rel] = exp_map;
     }
-    Node fst_element = RelsUtils::nthElementOfTuple( exp[0], 0 );
-    Node snd_element = RelsUtils::nthElementOfTuple( exp[0], 1 );
+    Node fst_element = TupleUtils::nthElementOfTuple( exp[0], 0 );
+    Node snd_element = TupleUtils::nthElementOfTuple( exp[0], 1 );
     Node sk_1 = d_skCache.mkTypedSkolemCached(fst_element.getType(),
                                               exp[0],
                                               tc_rel[0],
@@ -631,8 +637,8 @@ void TheorySetsRels::check(Theory::Effort level)
     if( tc_it != d_rRep_tcGraph.end() ) {
       bool isReachable = false;
       std::unordered_set<Node> seen;
-      isTCReachable( getRepresentative( RelsUtils::nthElementOfTuple(mem_rep, 0) ),
-                     getRepresentative( RelsUtils::nthElementOfTuple(mem_rep, 1) ), seen, tc_it->second, isReachable );
+      isTCReachable( getRepresentative(TupleUtils::nthElementOfTuple(mem_rep, 0) ),
+                     getRepresentative(TupleUtils::nthElementOfTuple(mem_rep, 1) ), seen, tc_it->second, isReachable );
       return isReachable;
     }
     return false;
@@ -680,8 +686,8 @@ void TheorySetsRels::check(Theory::Effort level)
 
     for (size_t i = 0, msize = members.size(); i < msize; i++)
     {
-      Node fst_element_rep = getRepresentative( RelsUtils::nthElementOfTuple( members[i], 0 ));
-      Node snd_element_rep = getRepresentative( RelsUtils::nthElementOfTuple( members[i], 1 ));
+      Node fst_element_rep = getRepresentative(TupleUtils::nthElementOfTuple( members[i], 0 ));
+      Node snd_element_rep = getRepresentative(TupleUtils::nthElementOfTuple( members[i], 1 ));
       Node tuple_rep = RelsUtils::constructPair( rel_rep, fst_element_rep, snd_element_rep );
       std::map<Node, std::unordered_set<Node> >::iterator rel_tc_graph_it =
           rel_tc_graph.find(fst_element_rep);
@@ -743,12 +749,15 @@ void TheorySetsRels::check(Theory::Effort level)
       std::unordered_set<Node>& seen)
   {
     NodeManager* nm = NodeManager::currentNM();
-    Node tc_mem = RelsUtils::constructPair( tc_rel, RelsUtils::nthElementOfTuple((reasons.front())[0], 0), RelsUtils::nthElementOfTuple((reasons.back())[0], 1) );
+    Node tc_mem = RelsUtils::constructPair( tc_rel,
+        TupleUtils::nthElementOfTuple((reasons.front())[0], 0),
+        TupleUtils::nthElementOfTuple((reasons.back())[0], 1) );
     std::vector< Node > all_reasons( reasons );
 
     for( unsigned int i = 0 ; i < reasons.size()-1; i++ ) {
-      Node fst_element_end = RelsUtils::nthElementOfTuple( reasons[i][0], 1 );
-      Node snd_element_begin = RelsUtils::nthElementOfTuple( reasons[i+1][0], 0 );
+      Node fst_element_end = TupleUtils::nthElementOfTuple( reasons[i][0], 1 );
+      Node snd_element_begin =
+          TupleUtils::nthElementOfTuple( reasons[i+1][0], 0 );
       if( fst_element_end != snd_element_begin ) {
         all_reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, fst_element_end, snd_element_begin) );
       }
@@ -823,12 +832,12 @@ void TheorySetsRels::check(Theory::Effort level)
 
     unsigned int i = 0;
     for(; i < s1_len; ++i) {
-      r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
+      r1_element.push_back(TupleUtils::nthElementOfTuple(mem, i));
     }
     const DType& dt2 = pt_rel[1].getType().getSetElementType().getDType();
     r2_element.push_back(dt2[0].getConstructor());
     for(; i < tup_len; ++i) {
-      r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
+      r2_element.push_back(TupleUtils::nthElementOfTuple(mem, i));
     }
     Node reason   = exp;
     Node mem1     = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
@@ -885,14 +894,14 @@ void TheorySetsRels::check(Theory::Effort level)
     unsigned int i = 0;
     r1_element.push_back(dt1[0].getConstructor());
     for(; i < s1_len-1; ++i) {
-      r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
+      r1_element.push_back(TupleUtils::nthElementOfTuple(mem, i));
     }
     r1_element.push_back(shared_x);
     const DType& dt2 = join_rel[1].getType().getSetElementType().getDType();
     r2_element.push_back(dt2[0].getConstructor());
     r2_element.push_back(shared_x);
     for(; i < tup_len; ++i) {
-      r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i));
+      r2_element.push_back(TupleUtils::nthElementOfTuple(mem, i));
     }
     Node mem1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
     Node mem2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_element);
@@ -966,7 +975,7 @@ void TheorySetsRels::check(Theory::Effort level)
     }
 
     Node reason = exp;
-    Node reversed_mem = RelsUtils::reverseTuple( exp[0] );
+    Node reversed_mem = TupleUtils::reverseTuple( exp[0] );
 
     if( tp_rel != exp[1] ) {
       reason = NodeManager::currentNM()->mkNode(kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, tp_rel, exp[1]));
@@ -1063,7 +1072,7 @@ void TheorySetsRels::check(Theory::Effort level)
               kind::AND, reason, nm->mkNode(kind::EQUAL, rel[0], exps[i][1]));
         }
         sendInfer(
-            nm->mkNode(SET_MEMBER, RelsUtils::reverseTuple(exps[i][0]), rel),
+            nm->mkNode(SET_MEMBER, TupleUtils::reverseTuple(exps[i][0]), rel),
             InferenceId::SETS_RELS_TRANSPOSE_REV,
             reason);
       }
@@ -1108,9 +1117,8 @@ void TheorySetsRels::check(Theory::Effort level)
         std::vector<Node> reasons;
         if (rk == kind::RELATION_JOIN)
         {
-          Node r1_rmost =
-              RelsUtils::nthElementOfTuple(r1_rep_exps[i][0], r1_tuple_len - 1);
-          Node r2_lmost = RelsUtils::nthElementOfTuple(r2_rep_exps[j][0], 0);
+          Node r1_rmost = TupleUtils::nthElementOfTuple(r1_rep_exps[i][0], r1_tuple_len - 1);
+          Node r2_lmost = TupleUtils::nthElementOfTuple(r2_rep_exps[j][0], 0);
           // Since we require notification r1_rmost and r2_lmost are equal,
           // they must be shared terms of theory of sets. Hence, we make the
           // following calls to makeSharedTerm to ensure this is the case.
@@ -1140,14 +1148,18 @@ void TheorySetsRels::check(Theory::Effort level)
           unsigned int l = 1;
 
           for( ; k < r1_tuple_len - 1; ++k ) {
-            tuple_elements.push_back( RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) );
+            tuple_elements.push_back(
+                TupleUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) );
           }
           if(isProduct) {
-            tuple_elements.push_back( RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) );
-            tuple_elements.push_back( RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], 0 ) );
+            tuple_elements.push_back(
+                TupleUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) );
+            tuple_elements.push_back(
+                TupleUtils::nthElementOfTuple( r2_rep_exps[j][0], 0 ) );
           }
           for( ; l < r2_tuple_len; ++l ) {
-            tuple_elements.push_back( RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], l ) );
+            tuple_elements.push_back(
+                TupleUtils::nthElementOfTuple( r2_rep_exps[j][0], l ) );
           }
 
           Node composed_tuple =
@@ -1216,8 +1228,8 @@ void TheorySetsRels::check(Theory::Effort level)
       size_t tlen = atn.getTupleLength();
       for (size_t i = 0; i < tlen; i++)
       {
-        if (!areEqual(RelsUtils::nthElementOfTuple(a, i),
-                      RelsUtils::nthElementOfTuple(b, i)))
+        if (!areEqual(TupleUtils::nthElementOfTuple(a, i),
+                      TupleUtils::nthElementOfTuple(b, i)))
         {
           return false;
         }
@@ -1278,7 +1290,7 @@ void TheorySetsRels::check(Theory::Effort level)
   void TheorySetsRels::computeTupleReps( Node n ) {
     if( d_tuple_reps.find( n ) == d_tuple_reps.end() ){
       for( unsigned i = 0; i < n.getType().getTupleLength(); i++ ){
-        d_tuple_reps[n].push_back( getRepresentative( RelsUtils::nthElementOfTuple(n, i) ) );
+        d_tuple_reps[n].push_back( getRepresentative(TupleUtils::nthElementOfTuple(n, i) ) );
       }
     }
   }
@@ -1295,7 +1307,7 @@ void TheorySetsRels::check(Theory::Effort level)
       std::vector<TypeNode> tupleTypes = n[0].getType().getTupleTypes();
       for (unsigned int i = 0; i < n[0].getType().getTupleLength(); i++)
       {
-        Node element = RelsUtils::nthElementOfTuple(n[0], i);
+        Node element = TupleUtils::nthElementOfTuple(n[0], i);
         makeSharedTerm(element, tupleTypes[i]);
         tuple_elements.push_back(element);
       }
index cc642127c6228b4eaa526e0bdc6424efeb4b2463..6f6b9c38ef5e3607ca97ed0c7115869bcb4891e3 100644 (file)
 #include "theory/sets/theory_sets_rewriter.h"
 
 #include "expr/attribute.h"
+#include "expr/dtype.h"
 #include "expr/dtype_cons.h"
 #include "options/sets_options.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "theory/sets/normal_form.h"
 #include "theory/sets/rels_utils.h"
 #include "util/rational.h"
 
 using namespace cvc5::kind;
+using namespace cvc5::theory::datatypes;
 
 namespace cvc5 {
 namespace theory {
@@ -350,7 +353,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
       std::set<Node>::iterator tuple_it = tuple_set.begin();
 
       while(tuple_it != tuple_set.end()) {
-        new_tuple_set.insert(RelsUtils::reverseTuple(*tuple_it));
+        new_tuple_set.insert(TupleUtils::reverseTuple(*tuple_it));
         ++tuple_it;
       }
       Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType());
@@ -389,7 +392,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
         std::vector<Node> left_tuple;
         left_tuple.push_back(tn.getDType()[0].getConstructor());
         for(int i = 0; i < left_len; i++) {
-          left_tuple.push_back(RelsUtils::nthElementOfTuple(*left_it,i));
+          left_tuple.push_back(TupleUtils::nthElementOfTuple(*left_it,i));
         }
         std::set<Node>::iterator right_it = right.begin();
         int right_len = (*right_it).getType().getTupleLength();
@@ -397,7 +400,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
           Trace("rels-debug") << "Sets::postRewrite processing right_it = " <<  *right_it << std::endl;
           std::vector<Node> right_tuple;
           for(int j = 0; j < right_len; j++) {
-            right_tuple.push_back(RelsUtils::nthElementOfTuple(*right_it,j));
+            right_tuple.push_back(TupleUtils::nthElementOfTuple(*right_it,j));
           }
           std::vector<Node> new_tuple;
           new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end());
@@ -437,15 +440,16 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
         std::vector<Node> left_tuple;
         left_tuple.push_back(tn.getDType()[0].getConstructor());
         for(int i = 0; i < left_len - 1; i++) {
-          left_tuple.push_back(RelsUtils::nthElementOfTuple(*left_it,i));
+          left_tuple.push_back(TupleUtils::nthElementOfTuple(*left_it,i));
         }
         std::set<Node>::iterator right_it = right.begin();
         int right_len = (*right_it).getType().getTupleLength();
         while(right_it != right.end()) {
-          if(RelsUtils::nthElementOfTuple(*left_it,left_len-1) == RelsUtils::nthElementOfTuple(*right_it,0)) {
+          if(TupleUtils::nthElementOfTuple(*left_it,left_len-1) == TupleUtils::nthElementOfTuple(*right_it,0)) {
             std::vector<Node> right_tuple;
             for(int j = 1; j < right_len; j++) {
-              right_tuple.push_back(RelsUtils::nthElementOfTuple(*right_it,j));
+              right_tuple.push_back(
+                  TupleUtils::nthElementOfTuple(*right_it,j));
             }
             std::vector<Node> new_tuple;
             new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end());
@@ -508,7 +512,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
       std::set<Node>::iterator rel_mems_it = rel_mems.begin();
 
       while( rel_mems_it != rel_mems.end() ) {
-        Node fst_mem = RelsUtils::nthElementOfTuple( *rel_mems_it, 0);
+        Node fst_mem = TupleUtils::nthElementOfTuple( *rel_mems_it, 0);
         iden_rel_mems.insert(RelsUtils::constructPair(node, fst_mem, fst_mem));
         ++rel_mems_it;
       }
@@ -548,7 +552,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
       std::set<Node>::iterator rel_mems_it = rel_mems.begin();
 
       while( rel_mems_it != rel_mems.end() ) {
-        Node fst_mem = RelsUtils::nthElementOfTuple( *rel_mems_it, 0);
+        Node fst_mem = TupleUtils::nthElementOfTuple( *rel_mems_it, 0);
         if( has_checked.find( fst_mem ) != has_checked.end() ) {
           ++rel_mems_it;
           continue;
@@ -557,9 +561,10 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
         std::set<Node> existing_mems;
         std::set<Node>::iterator rel_mems_it_snd = rel_mems.begin();
         while( rel_mems_it_snd != rel_mems.end() ) {
-          Node fst_mem_snd = RelsUtils::nthElementOfTuple( *rel_mems_it_snd, 0);
+          Node fst_mem_snd = TupleUtils::nthElementOfTuple( *rel_mems_it_snd, 0);
           if( fst_mem == fst_mem_snd ) {
-            existing_mems.insert( RelsUtils::nthElementOfTuple( *rel_mems_it_snd, 1) );
+            existing_mems.insert(
+                TupleUtils::nthElementOfTuple( *rel_mems_it_snd, 1) );
           }
           ++rel_mems_it_snd;
         }
index c49d5004c8f65ca35a68f051b6cb70b09c76e6f9..67603be826577032bf3f19fcac8b47b2e8411c8a 100644 (file)
@@ -1678,6 +1678,9 @@ set(regress_1_tests
   regress1/bags/murxla1.smt2
   regress1/bags/murxla2.smt2
   regress1/bags/murxla3.smt2
+  regress1/bags/product1.smt2
+  regress1/bags/product2.smt2
+  regress1/bags/product3.smt2
   regress1/bags/subbag1.smt2
   regress1/bags/subbag2.smt2
   regress1/bags/union_disjoint.smt2
diff --git a/test/regress/regress1/bags/product1.smt2 b/test/regress/regress1/bags/product1.smt2
new file mode 100644 (file)
index 0000000..2f7f090
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag (Tuple String)))
+(declare-fun B () (Bag (Tuple Bool)))
+(declare-fun C () (Bag (Tuple String Bool)))
+(declare-fun x () (Tuple String))
+(declare-fun y () (Tuple Bool))
+(assert (= (bag.count x A) 5))
+(assert (= (bag.count y B) 4))
+(assert (= C (table.product A B)))
+(check-sat)
diff --git a/test/regress/regress1/bags/product2.smt2 b/test/regress/regress1/bags/product2.smt2
new file mode 100644 (file)
index 0000000..ee7d171
--- /dev/null
@@ -0,0 +1,14 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag (Tuple Int Int Int)))
+(declare-fun B () (Bag (Tuple Int Int Int)))
+(declare-fun x () (Tuple Int Int Int))
+(assert (= x (tuple 1 2 3)))
+(declare-fun y () (Tuple Int Int Int))
+(assert (= y (tuple 3 2 1)))
+(declare-fun z () (Tuple Int Int Int Int Int Int))
+(assert (= z (tuple 1 2 3 3 2 1)))
+(assert (bag.member x A))
+(assert (bag.member y B))
+(assert (not (bag.member z (table.product A B))))
+(check-sat)
diff --git a/test/regress/regress1/bags/product3.smt2 b/test/regress/regress1/bags/product3.smt2
new file mode 100644 (file)
index 0000000..8f2e8c3
--- /dev/null
@@ -0,0 +1,21 @@
+(set-logic ALL)
+
+(set-info :status sat)
+
+(declare-fun A () (Bag (Tuple Int Int Int)))
+(declare-fun B () (Bag (Tuple Int Int Int)))
+(declare-fun C () (Bag (Tuple Int Int Int Int Int Int)))
+
+(assert (= C (table.product A B)))
+
+(declare-fun x () (Tuple Int Int Int))
+(declare-fun y () (Tuple Int Int Int))
+(declare-fun z () (Tuple Int Int Int Int Int Int))
+
+(assert (bag.member x A))
+(assert (bag.member y B))
+(assert (bag.member z C))
+
+(assert (distinct x y ((_ tuple_project 0 1 2) z) ((_ tuple_project 3 4 5) z)))
+
+(check-sat)