Add set.fold operator (#8867)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 23 Jun 2022 14:18:41 +0000 (09:18 -0500)
committerGitHub <noreply@github.com>
Thu, 23 Jun 2022 14:18:41 +0000 (14:18 +0000)
23 files changed:
proofs/lfsc/signatures/theory_def.plf
src/CMakeLists.txt
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/inference_id.cpp
src/theory/inference_id.h
src/theory/sets/kinds
src/theory/sets/set_reduction.cpp [new file with mode: 0644]
src/theory/sets/set_reduction.h [new file with mode: 0644]
src/theory/sets/theory_sets.cpp
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_rewriter.cpp
src/theory/sets/theory_sets_rewriter.h
src/theory/sets/theory_sets_type_rules.cpp
src/theory/sets/theory_sets_type_rules.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/sets/set_fold1.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/set_fold2.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/set_fold3.smt2 [new file with mode: 0644]

index 30f00dc2957174dddd369c2a68264f41e913cac6..c5205f7345cea8fe992a2af32692afdcd93e6e75 100644 (file)
 (define set.filter (# x term (# y term (apply (apply f_set.filter x) y))))
 (declare f_set.map term)
 (define set.map (# x term (# y term (apply (apply f_set.map x) y))))
+(declare f_set.fold term)
+(define set.fold (# x term (# y term (# z term (apply (apply (apply f_set.fold x) y) z)))))
 
 ;; ---- Bags
 (declare bag.empty (! s sort term))
index 995f15c85fbcd0a2977db7648eac4310b349df37..4150c9646e2461a39bd52c5ba87694ae46456a15 100644 (file)
@@ -1012,6 +1012,8 @@ libcvc5_add_sources(
   theory/sets/normal_form.h
   theory/sets/rels_utils.cpp
   theory/sets/rels_utils.h
+  theory/sets/set_reduction.cpp
+  theory/sets/set_reduction.h
   theory/sets/skolem_cache.cpp
   theory/sets/skolem_cache.h
   theory/sets/solver_state.cpp
index ed46fa3ed17e003089e0256129c48917078fdc4d..52c2045b0f8de20aca26195c2d25bc29aa1cf756 100644 (file)
@@ -301,6 +301,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(SET_IS_SINGLETON, internal::Kind::SET_IS_SINGLETON),
         KIND_ENUM(SET_MAP, internal::Kind::SET_MAP),
         KIND_ENUM(SET_FILTER, internal::Kind::SET_FILTER),
+        KIND_ENUM(SET_FOLD, internal::Kind::SET_FOLD),
         /* Relations -------------------------------------------------------- */
         KIND_ENUM(RELATION_JOIN, internal::Kind::RELATION_JOIN),
         KIND_ENUM(RELATION_PRODUCT, internal::Kind::RELATION_PRODUCT),
@@ -624,6 +625,7 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::SET_IS_SINGLETON, SET_IS_SINGLETON},
         {internal::Kind::SET_MAP, SET_MAP},
         {internal::Kind::SET_FILTER, SET_FILTER},
+        {internal::Kind::SET_FOLD, SET_FOLD},
         /* Relations ------------------------------------------------------- */
         {internal::Kind::RELATION_JOIN, RELATION_JOIN},
         {internal::Kind::RELATION_PRODUCT, RELATION_PRODUCT},
index 453e2b2528b0417cb1e3362bda7b306dc01c26ff..9aa4824d7ac30f24194bb5458069dafc68247b22 100644 (file)
@@ -3190,6 +3190,32 @@ enum Kind : int32_t
    * \endrst
    */
    SET_FILTER,
+   /**
+   * Set fold.
+   *
+   * \rst
+   * This operator combines elements of a set into a single value.
+   * (set.fold :math:`f \; t \; A`) folds the elements of set :math:`A`
+   * starting with Term :math:`t` and using the combining function :math:`f`.
+   *
+   * - Arity: ``2``
+   *
+   *   - ``1:`` Term of function Sort :math:`(\rightarrow S_1 \; S_2 \; S_2)`
+   *   - ``2:`` Term of Sort :math:`S_2` (the initial value)
+   *   - ``3:`` Term of bag Sort (Set :math:`S_1`)
+   * \endrst
+   *
+   * - Create Term of this Kind with:
+   *
+   *   - Solver::mkTerm(Kind, const std::vector<Term>&) const
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
+   */
+  SET_FOLD,
   /* Relations ------------------------------------------------------------- */
 
   /**
@@ -3693,10 +3719,6 @@ enum Kind : int32_t
    *   - Solver::mkTerm(Kind, const std::vector<Term>&) const
    *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
    *
-   * - Create Op of this kind with:
-   *
-   *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
-   *
    * \rst
    * .. warning:: This kind is experimental and may be changed or removed in
    *              future versions.
index d410da65088123df4830aac5ad4e0e3bbded17e1..2171954ba5da89481bbb5671f7fc22dcd342b2df 100644 (file)
@@ -97,6 +97,10 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::BAGS_DEQ_DIFF: return "BAGS_DEQ_DIFF";
     case SkolemFunId::SETS_CHOOSE: return "SETS_CHOOSE";
     case SkolemFunId::SETS_DEQ_DIFF: return "SETS_DEQ_DIFF";
+    case SkolemFunId::SETS_FOLD_CARD: return "SETS_FOLD_CARD";
+    case SkolemFunId::SETS_FOLD_COMBINE: return "SETS_FOLD_COMBINE";
+    case SkolemFunId::SETS_FOLD_ELEMENTS: return "SETS_FOLD_ELEMENTS";
+    case SkolemFunId::SETS_FOLD_UNION: return "SETS_FOLD_UNION";
     case SkolemFunId::SETS_MAP_DOWN_ELEMENT: return "SETS_MAP_DOWN_ELEMENT";
     case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED";
     default: return "?";
index 1d9a74c7a92c44b1009a082cfbf190c477dfbfc1..1b2f3df3bb3fad6453b7984cf1d2c4f3ad048183 100644 (file)
@@ -182,6 +182,10 @@ enum class SkolemFunId
   SETS_CHOOSE,
   /** set diff to witness (not (= A B)) */
   SETS_DEQ_DIFF,
+  SETS_FOLD_CARD,
+  SETS_FOLD_COMBINE,
+  SETS_FOLD_ELEMENTS,
+  SETS_FOLD_UNION,
   /**
    * A skolem variable that is unique per terms (set.map f A), y which is an
    * element in (set.map f A). The skolem is constrained to be an element in A,
index 122f7dbba892c220b94b76e79f06c0172b1f45b1..8c5be875a305330200a58394223b9519608f683a 100644 (file)
@@ -606,6 +606,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(cvc5::SET_IS_SINGLETON, "set.is_singleton");
     addOperator(cvc5::SET_MAP, "set.map");
     addOperator(cvc5::SET_FILTER, "set.filter");
+    addOperator(cvc5::SET_FOLD, "set.fold");
     addOperator(cvc5::RELATION_JOIN, "rel.join");
     addOperator(cvc5::RELATION_PRODUCT, "rel.product");
     addOperator(cvc5::RELATION_TRANSPOSE, "rel.transpose");
index f05825324e5455438aebc75a9b174861f5bf08a4..2f0df00a876b093e2f3589c938685357c177da06 100644 (file)
@@ -1152,6 +1152,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::SET_IS_SINGLETON: return "set.is_singleton";
   case kind::SET_MAP: return "set.map";
   case kind::SET_FILTER: return "set.filter";
+  case kind::SET_FOLD: return "set.fold";
   case kind::RELATION_JOIN: return "rel.join";
   case kind::RELATION_PRODUCT: return "rel.product";
   case kind::RELATION_TRANSPOSE: return "rel.transpose";
index 495971dc95e372c08d2d799bfd719432da512dc4..854070b0cdc55857b86fec89db2f17f2f7bf9e1a 100644 (file)
@@ -335,6 +335,7 @@ const char* toString(InferenceId i)
     case InferenceId::SETS_EQ_MEM_CONFLICT: return "SETS_EQ_MEM_CONFLICT";
     case InferenceId::SETS_FILTER_DOWN: return "SETS_FILTER_DOWN";
     case InferenceId::SETS_FILTER_UP: return "SETS_FILTER_UP";
+    case InferenceId::SETS_FOLD: return "SETS_FOLD";
     case InferenceId::SETS_MAP_DOWN_POSITIVE: return "SETS_MAP_DOWN_POSITIVE";
     case InferenceId::SETS_MAP_UP: return "SETS_MAP_UP";
     case InferenceId::SETS_MEM_EQ: return "SETS_MEM_EQ";
index 4166560e3446a5b94e5a5e4de76c0b27bf1c7473..723345a17c00da14cc4d2a54ad53f6c1bfdffe42 100644 (file)
@@ -488,6 +488,7 @@ enum class InferenceId
   SETS_EQ_MEM_CONFLICT,
   SETS_FILTER_DOWN,
   SETS_FILTER_UP,
+  SETS_FOLD,
   SETS_MAP_DOWN_POSITIVE,
   SETS_MAP_UP,
   SETS_MEM_EQ,
index d1e22cab103fca6ad247dff154b53e640b8bc2f3..e3e5e06b656eb30e1745c0a7a0aca9622a4d610b 100644 (file)
@@ -81,6 +81,14 @@ operator SET_MAP           2  "set map function"
 # and returns the same set excluding those elements that do not satisfy the predicate
 operator SET_FILTER        2  "set filter operator"
 
+# set.fold operator combines elements of a set into a single value.
+# (set.fold f t A) folds the elements of bag A starting with term t and using
+# the combining function f.
+#  f: a binary operation of type (-> T1 T2 T2)
+#  t: an initial value of type T2
+#  A: a bag of type (Set T1)
+operator SET_FOLD          3  "set fold operator"
+
 operator RELATION_JOIN                    2  "relation join"
 operator RELATION_PRODUCT         2  "relation cartesian product"
 operator RELATION_TRANSPOSE    1  "relation transpose"
@@ -104,6 +112,7 @@ typerule SET_CHOOSE         ::cvc5::internal::theory::sets::ChooseTypeRule
 typerule SET_IS_SINGLETON   ::cvc5::internal::theory::sets::IsSingletonTypeRule
 typerule SET_MAP            ::cvc5::internal::theory::sets::SetMapTypeRule
 typerule SET_FILTER         ::cvc5::internal::theory::sets::SetFilterTypeRule
+typerule SET_FOLD           ::cvc5::internal::theory::sets::SetFoldTypeRule
 
 typerule RELATION_JOIN                         ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule
 typerule RELATION_PRODUCT              ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule
diff --git a/src/theory/sets/set_reduction.cpp b/src/theory/sets/set_reduction.cpp
new file mode 100644 (file)
index 0000000..78b5b80
--- /dev/null
@@ -0,0 +1,125 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed, Andrew Reynolds, Aina Niemetz
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2022 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * set reduction.
+ */
+
+#include "theory/sets/set_reduction.h"
+
+#include "expr/bound_var_manager.h"
+#include "expr/emptyset.h"
+#include "expr/skolem_manager.h"
+#include "theory/datatypes/tuple_utils.h"
+#include "theory/quantifiers/fmf/bounded_integers.h"
+#include "util/rational.h"
+
+using namespace cvc5::internal;
+using namespace cvc5::internal::kind;
+
+namespace cvc5::internal {
+namespace theory {
+namespace sets {
+
+SetReduction::SetReduction() {}
+
+SetReduction::~SetReduction() {}
+
+/**
+ * A bound variable corresponding to the universally quantified integer
+ * variable used to range over (may be distinct) elements in a set, used
+ * for axiomatizing the behavior of some term.
+ * If there are multiple quantifiers, this variable should be the first one.
+ */
+struct FirstIndexVarAttributeId
+{
+};
+typedef expr::Attribute<FirstIndexVarAttributeId, Node> FirstIndexVarAttribute;
+
+/**
+ * A bound variable corresponding to the universally quantified integer
+ * variable used to range over (may be distinct) elements in a set, used
+ * for axiomatizing the behavior of some term.
+ * This variable should be the second of multiple quantifiers.
+ */
+struct SecondIndexVarAttributeId
+{
+};
+typedef expr::Attribute<SecondIndexVarAttributeId, Node>
+    SecondIndexVarAttribute;
+
+Node SetReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
+{
+  Assert(node.getKind() == SET_FOLD);
+  NodeManager* nm = NodeManager::currentNM();
+  SkolemManager* sm = nm->getSkolemManager();
+  Node f = node[0];
+  Node t = node[1];
+  Node A = node[2];
+  Node zero = nm->mkConstInt(Rational(0));
+  Node one = nm->mkConstInt(Rational(1));
+  // types
+  TypeNode setType = A.getType();
+  TypeNode elementType = A.getType().getSetElementType();
+  TypeNode integerType = nm->integerType();
+  TypeNode ufType = nm->mkFunctionType(integerType, elementType);
+  TypeNode resultType = t.getType();
+  TypeNode combineType = nm->mkFunctionType(integerType, resultType);
+  TypeNode unionType = nm->mkFunctionType(integerType, setType);
+  // skolem functions
+  Node n = sm->mkSkolemFunction(SkolemFunId::SETS_FOLD_CARD, integerType, A);
+  Node uf = sm->mkSkolemFunction(SkolemFunId::SETS_FOLD_ELEMENTS, ufType, A);
+  Node unionNode =
+      sm->mkSkolemFunction(SkolemFunId::SETS_FOLD_UNION, unionType, A);
+  Node combine = sm->mkSkolemFunction(
+      SkolemFunId::SETS_FOLD_COMBINE, combineType, {f, t, A});
+
+  BoundVarManager* bvm = nm->getBoundVarManager();
+  Node i =
+      bvm->mkBoundVar<FirstIndexVarAttribute>(node, "i", nm->integerType());
+  Node iList = nm->mkNode(BOUND_VAR_LIST, i);
+  Node iMinusOne = nm->mkNode(SUB, i, one);
+  Node uf_i = nm->mkNode(APPLY_UF, uf, i);
+  Node combine_0 = nm->mkNode(APPLY_UF, combine, zero);
+  Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne);
+  Node combine_i = nm->mkNode(APPLY_UF, combine, i);
+  Node combine_n = nm->mkNode(APPLY_UF, combine, n);
+  Node union_0 = nm->mkNode(APPLY_UF, unionNode, zero);
+  Node union_iMinusOne = nm->mkNode(APPLY_UF, unionNode, iMinusOne);
+  Node union_i = nm->mkNode(APPLY_UF, unionNode, i);
+  Node union_n = nm->mkNode(APPLY_UF, unionNode, n);
+  Node combine_0_equal = combine_0.eqNode(t);
+  Node combine_i_equal =
+      combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne));
+  Node union_0_equal = union_0.eqNode(nm->mkConst(EmptySet(setType)));
+  Node singleton = nm->mkNode(SET_SINGLETON, uf_i);
+
+  Node union_i_equal =
+      union_i.eqNode(nm->mkNode(SET_UNION, singleton, union_iMinusOne));
+  Node interval_i =
+      nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n));
+
+  Node body_i = nm->mkNode(
+      IMPLIES, interval_i, nm->mkNode(AND, combine_i_equal, union_i_equal));
+  Node forAll_i = quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
+  Node nonNegative = nm->mkNode(GEQ, n, zero);
+  Node union_n_equal = A.eqNode(union_n);
+  asserts.push_back(forAll_i);
+  asserts.push_back(combine_0_equal);
+  asserts.push_back(union_0_equal);
+  asserts.push_back(union_n_equal);
+  asserts.push_back(nonNegative);
+  return combine_n;
+}
+
+}  // namespace sets
+}  // namespace theory
+}  // namespace cvc5::internal
diff --git a/src/theory/sets/set_reduction.h b/src/theory/sets/set_reduction.h
new file mode 100644 (file)
index 0000000..4317201
--- /dev/null
@@ -0,0 +1,73 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2022 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * set reduction.
+ */
+
+#ifndef CVC5__THEORY__SETS__SET_REDUCTION_H
+#define CVC5__THEORY__SETS__SET_REDUCTION_H
+
+#include <vector>
+
+#include "cvc5_private.h"
+#include "smt/env_obj.h"
+
+namespace cvc5::internal {
+namespace theory {
+namespace sets {
+
+/**
+ * class for set reductions
+ */
+class SetReduction
+{
+ public:
+  SetReduction();
+  ~SetReduction();
+
+  /**
+   * @param node a term of the form (set.fold f t A) where
+   *        f: (-> T1 T2 T2) is a binary operation
+   *        t: T2 is the initial value
+   *        A: (Set T1) is a set
+   * @param asserts a list of assertions generated by this reduction
+   * @return the reduction term (combine n) with asserts:
+   * - (forall ((i Int))
+   *    (let ((iMinusOne (- i 1)))
+   *      (let ((uf_i (uf i)))
+   *        (=>
+   *          (and (>= i 1) (<= i n))
+   *          (and
+   *            (= (combine i) (f uf_i (combine iMinusOne)))
+   *            (=
+   *              (unionFn i)
+   *              (set.union
+   *                (set.singleton uf_i)
+   *                (unionFn iMinusOne))))))))
+   * - (= (combine 0) t)
+   * - (= (unionFn 0) (as set.empty (Set T1)))
+   * - (= A (unionFn n))
+   * - (>= n 0))
+   * where
+   * n: Int is the cardinality of set A
+   * uf:Int -> T1 is an uninterpreted function that represents elements of A
+   * combine: Int -> T2 is an uninterpreted function
+   * unionFn: Int -> (Set T1) is an uninterpreted function
+   */
+  static Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
+};
+
+}  // namespace sets
+}  // namespace theory
+}  // namespace cvc5::internal
+
+#endif /* CVC5__THEORY__SETS__SET_REDUCTION_H */
index 454a86a73cbf6d4a5ca34f2b50f1e04378343235..820f33e3bcf7d418ae77e6709979ff25e8de1ca5 100644 (file)
@@ -16,6 +16,7 @@
 #include "theory/sets/theory_sets.h"
 
 #include "options/sets_options.h"
+#include "theory/sets/set_reduction.h"
 #include "theory/sets/theory_sets_private.h"
 #include "theory/sets/theory_sets_rewriter.h"
 #include "theory/theory_model.h"
@@ -158,6 +159,15 @@ TrustNode TheorySets::ppRewrite(TNode n, std::vector<SkolemLemma>& lems)
       throw LogicException(ss.str());
     }
   }
+  if (nk == SET_FOLD)
+  {
+    std::vector<Node> asserts;
+    Node ret = SetReduction::reduceFoldOperator(n, asserts);
+    NodeManager* nm = NodeManager::currentNM();
+    Node andNode = nm->mkNode(AND, asserts);
+    d_im.lemma(andNode, InferenceId::BAGS_FOLD);
+    return TrustNode::mkTrustRewrite(n, ret, nullptr);
+  }
   return d_internal->ppRewrite(n, lems);
 }
 
index 8c2665eb8eb58947e1e4d7b977cbf9ac903398ac..bf102204821848213a066f6ce74b4f0a11e31446 100644 (file)
@@ -1244,7 +1244,7 @@ void TheorySetsPrivate::processCarePairArgs(TNode a, TNode b)
 
 bool TheorySetsPrivate::isHigherOrderKind(Kind k)
 {
-  return k == SET_MAP || k == SET_FILTER;
+  return k == SET_MAP || k == SET_FILTER || k == SET_FOLD;
 }
 
 Node TheorySetsPrivate::explain(TNode literal)
index 9a0b5a875fb4b2415c4781cb27a977121ef0d08a..8ee9a33d1d82719664f573c95a797a48285f9de1 100644 (file)
@@ -334,6 +334,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
 
   case SET_MAP: return postRewriteMap(node);
   case SET_FILTER: return postRewriteFilter(node);
+  case SET_FOLD: return postRewriteFold(node);
 
   case kind::RELATION_TRANSPOSE:
   {
@@ -705,6 +706,41 @@ RewriteResponse TheorySetsRewriter::postRewriteFilter(TNode n)
   }
 }
 
+RewriteResponse TheorySetsRewriter::postRewriteFold(TNode n)
+{
+  Assert(n.getKind() == kind::SET_FOLD);
+  NodeManager* nm = NodeManager::currentNM();
+  Node f = n[0];
+  Node t = n[1];
+  Kind k = n[2].getKind();
+  switch (k)
+  {
+    case SET_EMPTY:
+    {
+      // ((set.fold f t (as set.empty (Set T))) = t
+      return RewriteResponse(REWRITE_DONE, t);
+    }
+    case SET_SINGLETON:
+    {
+      // (set.fold f t (set.singleton x)) = (f x t)
+      Node x = n[2][0];
+      Node f_x_t = nm->mkNode(APPLY_UF, f, x, t);
+      return RewriteResponse(REWRITE_AGAIN_FULL, f_x_t);
+    }
+    case SET_UNION:
+    {
+      // (set.fold f t (set.union B C)) = (set.fold f (set.fold f t A) B))
+      Node A = n[2][0];
+      Node B = n[2][1];
+      Node foldA = nm->mkNode(SET_FOLD, f, t, A);
+      Node fold = nm->mkNode(SET_FOLD, f, foldA, B);
+      return RewriteResponse(REWRITE_AGAIN_FULL, fold);
+    }
+
+    default: return RewriteResponse(REWRITE_DONE, n);
+  }
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index 74735a878b61282afa25d838fe881da80e81e39e..ba48e403018485a8e93aa03e77fd1cc6781ef53d 100644 (file)
@@ -94,6 +94,14 @@ private:
   *  where p: T -> Bool
   */
  RewriteResponse postRewriteFilter(TNode n);
+ /**
+  *  rewrites for n include:
+  *  - (set.fold f t (as set.empty (Set T))) = t
+  *  - (set.fold f t (set.singleton x)) = (f t x)
+  *  - (set.fold f t (set.union A B)) = (set.fold f (set.fold f t A) B))
+  *  where f: T -> S -> S, and t : S
+  */
+ RewriteResponse postRewriteFold(TNode n);
 }; /* class TheorySetsRewriter */
 
 }  // namespace sets
index 49bd24e171fe8d0e4615258cb2b4f36c7beb0c30..b2eb7987ae6df2ddd994bb38f0ab2abbf6ab5a44 100644 (file)
@@ -356,6 +356,57 @@ TypeNode SetFilterTypeRule::computeType(NodeManager* nodeManager,
   return setType;
 }
 
+TypeNode SetFoldTypeRule::computeType(NodeManager* nodeManager,
+                                      TNode n,
+                                      bool check)
+{
+  Assert(n.getKind() == kind::SET_FOLD);
+  TypeNode functionType = n[0].getType(check);
+  TypeNode initialValueType = n[1].getType(check);
+  TypeNode setType = n[2].getType(check);
+  if (check)
+  {
+    if (!setType.isSet())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n,
+          "set.fold operator expects a set in the third argument, "
+          "a non-set is found");
+    }
+
+    TypeNode elementType = setType.getSetElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T2 T2) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    TypeNode rangeType = functionType.getRangeType();
+    if (!(argTypes.size() == 2 && argTypes[0] == elementType
+          && argTypes[1] == rangeType))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T2 T2). "
+         << "Found a function of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    if (rangeType != initialValueType)
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects an initial value of type "
+         << rangeType << ". Found a term of type '" << initialValueType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  TypeNode retType = n[0].getType().getRangeType();
+  return retType;
+}
+
 TypeNode RelBinaryOperatorTypeRule::computeType(NodeManager* nodeManager,
                                                 TNode n,
                                                 bool check)
index 7461ec4cc0bc0c1b31728544f6e9c137e9770cc5..59ea661580372179fc9a5172c1caddd7e70e68c6 100644 (file)
@@ -149,6 +149,15 @@ struct SetFilterTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct SetFilterTypeRule */
 
+/**
+ * Type rule for (set.fold f t A) to make sure f is a binary operation of type
+ * (-> T1 T2 T2), t of type T2, and A is a set of type (Set T1)
+ */
+struct SetFoldTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct SetFoldTypeRule */
+
 /**
  * Type rule for binary operators (rel.join, rel.product) to check
  * if the two arguments are relations (set of tuples).
index 945f31ea8d5751e0f91862d39919a3a0760ed317..eeeae8692d99ccf8b7c1a83a99d03f6df171ba41 100644 (file)
@@ -2510,6 +2510,9 @@ set(regress_1_tests
   regress1/sets/set_filter2.smt2
   regress1/sets/set_filter3.smt2
   regress1/sets/set_filter4.smt2
+  regress1/sets/set_fold1.smt2
+  regress1/sets/set_fold2.smt2
+  regress1/sets/set_fold3.smt2
   regress1/sets/set_map_card_incomplete.smt2
   regress1/sets/set_map_negative_members.smt2
   regress1/sets/set_map_positive_members.smt2
diff --git a/test/regress/cli/regress1/sets/set_fold1.smt2 b/test/regress/cli/regress1/sets/set_fold1.smt2
new file mode 100644 (file)
index 0000000..285e645
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(define-fun plus ((x Int) (y Int)) Int (+ x y))
+(declare-fun A () (Set Int))
+(declare-fun sumPlus1 () Int)
+(declare-fun sumPlus2 () Int)
+(assert (= A (set.insert 1 2 (set.singleton 3))))
+(assert (= sumPlus1 (set.fold plus 1 A)))
+(assert (= sumPlus2 (set.fold plus 2 (as set.empty (Set Int)))))
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/set_fold2.smt2 b/test/regress/cli/regress1/sets/set_fold2.smt2
new file mode 100644 (file)
index 0000000..0d86665
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+(define-fun plus ((x Int) (y Int)) Int (+ x y))
+(declare-fun A () (Set Int))
+(declare-fun sumPlus1 () Int)
+(assert (= sumPlus1 (set.fold plus 1 A)))
+(assert (= sumPlus1 10))
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/set_fold3.smt2 b/test/regress/cli/regress1/sets/set_fold3.smt2
new file mode 100644 (file)
index 0000000..a5d80f9
--- /dev/null
@@ -0,0 +1,15 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+(set-option :strings-exp true)
+(define-fun min ((x String) (y String)) String (ite (str.< x y) x y))
+(declare-fun A () (Set String))
+(declare-fun x () String)
+(declare-fun minimum () String)
+(assert (= minimum (set.fold min "zzz" A)))
+(assert (str.< "aaa" minimum ))
+(assert (str.< minimum "zzz"))
+(assert (distinct x minimum))
+(assert (set.member x A))
+(check-sat)