Add kind BAG_MAP and its type rule to bags (#6503)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Mon, 30 Aug 2021 23:26:43 +0000 (18:26 -0500)
committerGitHub <noreply@github.com>
Mon, 30 Aug 2021 23:26:43 +0000 (23:26 +0000)
This PR adds kind BAG_MAP to bags.

16 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/kinds
src/theory/bags/normal_form.cpp
src/theory/bags/normal_form.h
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
test/regress/regress1/bags/map.smt2 [new file with mode: 0644]
test/unit/theory/theory_bags_rewriter_white.cpp
test/unit/theory/theory_bags_type_rules_white.cpp

index e245dc4153f0b6ee8e235e3650c8eb3a1c87f3f3..626edf7bbad5df94cd92a1c6e1e230e62418556b 100644 (file)
@@ -308,6 +308,7 @@ const static std::unordered_map<Kind, cvc5::Kind> s_kinds{
     {BAG_IS_SINGLETON, cvc5::Kind::BAG_IS_SINGLETON},
     {BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET},
     {BAG_TO_SET, cvc5::Kind::BAG_TO_SET},
+    {BAG_MAP, cvc5::Kind::BAG_MAP},
     /* Strings ------------------------------------------------------------- */
     {STRING_CONCAT, cvc5::Kind::STRING_CONCAT},
     {STRING_IN_REGEXP, cvc5::Kind::STRING_IN_REGEXP},
@@ -617,6 +618,7 @@ const static std::unordered_map<cvc5::Kind, Kind, cvc5::kind::KindHashFunction>
         {cvc5::Kind::BAG_IS_SINGLETON, BAG_IS_SINGLETON},
         {cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET},
         {cvc5::Kind::BAG_TO_SET, BAG_TO_SET},
+        {cvc5::Kind::BAG_MAP,BAG_MAP},
         /* Strings --------------------------------------------------------- */
         {cvc5::Kind::STRING_CONCAT, STRING_CONCAT},
         {cvc5::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
index e8b876b55ba8b2f1a34d6ad5c6f01ce504d88a1c..94a8a6f92650113c1d6ed721d307fa616352833b 100644 (file)
@@ -2515,6 +2515,21 @@ enum CVC5_EXPORT Kind : int32_t
    *   - `Solver::mkTerm(Kind kind, const Term& child) const`
    */
   BAG_TO_SET,
+  /**
+   * bag.map operator applies the first argument, a function of type (-> T1 T2),
+   * to every element of the second argument, a bag of type (Bag T1),
+   * and returns a bag of type (Bag T2).
+   *
+   * Parameters:
+   *   - 1: a function of type (-> T1 T2)
+   *   - 2: a bag of type (Bag T1)
+   *
+   * Create with:
+   *   - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2)
+   * const`
+   *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
+   */
+  BAG_MAP,
 
   /* Strings --------------------------------------------------------------- */
 
index 39492a98c4ee958c67e22ad40446ced6a137aab5..1a0a3d52a28ce5feb2f55894df47d2282d82c89d 100644 (file)
@@ -635,6 +635,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::BAG_IS_SINGLETON, "bag.is_singleton");
     addOperator(api::BAG_FROM_SET, "bag.from_set");
     addOperator(api::BAG_TO_SET, "bag.to_set");
+    addOperator(api::BAG_MAP, "bag.map");
   }
   if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) {
     defineType("String", d_solver->getStringSort(), true, true);
@@ -1103,7 +1104,7 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
         if ((*i).getSort().isFunction())
         {
           parseError(
-              "Cannot apply equalty to functions unless logic is prefixed by "
+              "Cannot apply equality to functions unless logic is prefixed by "
               "HO_.");
         }
       }
index 523b3efa9944afdb017969442903d80dcef96146..8a23a59ea4e3adfb23d216f12a40534a08bbe6f6 100644 (file)
@@ -1083,6 +1083,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_IS_SINGLETON: return "bag.is_singleton";
   case kind::BAG_FROM_SET: return "bag.from_set";
   case kind::BAG_TO_SET: return "bag.to_set";
+  case kind::BAG_MAP: return "bag.map";
 
     // fp theory
   case kind::FLOATINGPOINT_FP: return "fp";
index b9f620d51af8e80397edb444046bfa633e5ea1a8..f2af950878fc6b7bf61638c9c9edc5e7ea505bb4 100644 (file)
@@ -84,6 +84,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
       case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break;
       case BAG_FROM_SET: response = rewriteFromSet(n); break;
       case BAG_TO_SET: response = rewriteToSet(n); break;
+      case BAG_MAP: response = postRewriteMap(n); break;
       default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
     }
   }
@@ -505,6 +506,47 @@ BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
   return BagsRewriteResponse(n, Rewrite::NONE);
 }
 
+BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
+{
+  Assert(n.getKind() == kind::BAG_MAP);
+  if (n[1].isConst())
+  {
+    // (bag.map f emptybag) = emptybag
+    // (bag.map f (bag "a" 3) = (bag (f "a") 3)
+    std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
+    std::map<Node, Rational> mappedElements;
+    std::map<Node, Rational>::iterator it = elements.begin();
+    while (it != elements.end())
+    {
+      Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], it->first);
+      mappedElements[mappedElement] = it->second;
+      ++it;
+    }
+    TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType());
+    Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
+    return BagsRewriteResponse(ret, Rewrite::MAP_CONST);
+  }
+  Kind k = n[1].getKind();
+  switch (k)
+  {
+    case MK_BAG:
+    {
+      Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], n[1][0]);
+      Node ret = d_nm->mkNode(MK_BAG, mappedElement, n[1][0]);
+      return BagsRewriteResponse(ret, Rewrite::MAP_MK_BAG);
+    }
+
+    case UNION_DISJOINT:
+    {
+      Node a = d_nm->mkNode(BAG_MAP, n[1][0]);
+      Node b = d_nm->mkNode(BAG_MAP, n[1][1]);
+      Node ret = d_nm->mkNode(UNION_DISJOINT, a, b);
+      return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT);
+    }
+
+    default: return BagsRewriteResponse(n, Rewrite::NONE);
+  }
+}
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 83f364f9d3dd3eaed2c534ee4ccc04c503f4f471..eb5c9f9ab9dfabd3a89cfc611f365e49042f002b 100644 (file)
@@ -211,6 +211,18 @@ class BagsRewriter : public TheoryRewriter
    */
   BagsRewriteResponse postRewriteEqual(const TNode& n) const;
 
+  /**
+   *  rewrites for n include:
+   *  - (bag.map (lambda ((x U)) t) emptybag) = emptybag
+   *  - (bag.map (lambda ((x U)) t) (bag y z)) = (bag (apply (lambda ((x U)) t) y) z)
+   *  - (bag.map (lambda ((x U)) t) (union_disjoint A B)) =
+   *       (union_disjoint
+   *          (bag ((lambda ((x U)) t) "a") 3)
+   *          (bag ((lambda ((x U)) t) "b") 4))
+   *
+   */
+  BagsRewriteResponse postRewriteMap(const TNode& n) const;
+
  private:
   /** Reference to the rewriter statistics. */
   NodeManager* d_nm;
index 79541023905027cf88a455060962b002ec76471b..55fd28695c632ad12c625e816c9ce3b9017cea74 100644 (file)
@@ -72,6 +72,10 @@ operator BAG_TO_SET        1  "converts a bag to a set"
 # If the bag has cardinality > 1, then (choose A) will deterministically return an element in A.
 operator BAG_CHOOSE        1  "return an element in the bag given as a parameter"
 
+# The bag.map operator applies the first argument, a function of type (-> T1 T2), to every element
+# of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2).
+operator BAG_MAP           2  "bag map function"
+
 typerule UNION_MAX           ::cvc5::theory::bags::BinaryOperatorTypeRule
 typerule UNION_DISJOINT      ::cvc5::theory::bags::BinaryOperatorTypeRule
 typerule INTERSECTION_MIN    ::cvc5::theory::bags::BinaryOperatorTypeRule
@@ -88,6 +92,7 @@ typerule BAG_CHOOSE          ::cvc5::theory::bags::ChooseTypeRule
 typerule BAG_IS_SINGLETON    ::cvc5::theory::bags::IsSingletonTypeRule
 typerule BAG_FROM_SET        ::cvc5::theory::bags::FromSetTypeRule
 typerule BAG_TO_SET          ::cvc5::theory::bags::ToSetTypeRule
+typerule BAG_MAP            ::cvc5::theory::bags::BagMapTypeRule
 
 construle UNION_DISJOINT     ::cvc5::theory::bags::BinaryOperatorTypeRule
 construle MK_BAG             ::cvc5::theory::bags::MkBagTypeRule
index ec32d01381f878f8c097902585f191d9ac725da1..58445de5939a1bbefa6eae10ce16db1397823f47 100644 (file)
@@ -109,6 +109,7 @@ Node NormalForm::evaluate(TNode n)
     case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
     case BAG_FROM_SET: return evaluateFromSet(n);
     case BAG_TO_SET: return evaluateToSet(n);
+    case BAG_MAP: return evaluateBagMap(n);
     default: break;
   }
   Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
@@ -675,6 +676,35 @@ Node NormalForm::evaluateToSet(TNode n)
   return set;
 }
 
+
+Node NormalForm::evaluateBagMap(TNode n)
+{
+  Assert(n.getKind() == BAG_MAP);
+
+  // Examples
+  // --------
+  // - (bag.map ((lambda ((x String)) "z")
+  //            (union_disjoint (bag "a" 2) (bag "b" 3)) =
+  //     (union_disjoint
+  //       (bag ((lambda ((x String)) "z") "a") 2)
+  //       (bag ((lambda ((x String)) "z") "b") 3)) =
+  //     (bag "z" 5)
+
+  std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
+  std::map<Node, Rational> mappedElements;
+  std::map<Node, Rational>::iterator it = elements.begin();
+  NodeManager* nm = NodeManager::currentNM();
+  while (it != elements.end())
+  {
+    Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first);
+    mappedElements[mappedElement] = it->second;
+    ++it;
+  }
+  TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
+  Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
+  return ret;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 124ecdf5f066f4c980f280f5e3577e688d0903e1..f104e0381c03d824c7c915e8f36268e0b61bb138 100644 (file)
@@ -190,6 +190,11 @@ class NormalForm
    * @return a constant set constructed from the elements in A.
    */
   static Node evaluateToSet(TNode n);
+  /**
+   * @param n has the form (bag.map f A) where A is a constant bag
+   * @return a constant bag constructed from the images of elements in A.
+   */
+  static Node evaluateBagMap(TNode n);
 };
 }  // namespace bags
 }  // namespace theory
index ff77c4187adc4ad99b333e7713966f32f41a77c5..c8aeec14758991e7debc500001bafcf0e0bce83d 100644 (file)
@@ -44,6 +44,9 @@ const char* toString(Rewrite r)
     case Rewrite::INTERSECTION_SHARED_LEFT: return "INTERSECTION_SHARED_LEFT";
     case Rewrite::INTERSECTION_SHARED_RIGHT: return "INTERSECTION_SHARED_RIGHT";
     case Rewrite::IS_SINGLETON_MK_BAG: return "IS_SINGLETON_MK_BAG";
+    case Rewrite::MAP_CONST: return "MAP_CONST";
+    case Rewrite::MAP_MK_BAG: return "MAP_MK_BAG";
+    case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT";
     case Rewrite::MK_BAG_COUNT_NEGATIVE: return "MK_BAG_COUNT_NEGATIVE";
     case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION";
     case Rewrite::REMOVE_MIN: return "REMOVE_MIN";
index f5977332a8fdb9d35b2a5403b8c21da4c365f548..78eb502c88c0fe058ea68f4c2eb3a7e9e5dd97d7 100644 (file)
@@ -49,6 +49,9 @@ enum class Rewrite : uint32_t
   INTERSECTION_SHARED_LEFT,
   INTERSECTION_SHARED_RIGHT,
   IS_SINGLETON_MK_BAG,
+  MAP_CONST,
+  MAP_MK_BAG,
+  MAP_UNION_DISJOINT,
   MK_BAG_COUNT_NEGATIVE,
   REMOVE_FROM_UNION,
   REMOVE_MIN,
index d820ce6e1612595f0ac6fd50bc5d4a46172dbc23..7f45b9b1add0c20441d5ce35babe4515ac3b7f2b 100644 (file)
@@ -283,6 +283,48 @@ TypeNode ToSetTypeRule::computeType(NodeManager* nodeManager,
   return setType;
 }
 
+TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager,
+                                     TNode n,
+                                     bool check)
+{
+  Assert(n.getKind() == kind::BAG_MAP);
+  TypeNode functionType = n[0].getType(check);
+  TypeNode bagType = n[1].getType(check);
+  if (check)
+  {
+    if (!bagType.isBag())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n,
+          "bag.map operator expects a bag in the second argument, "
+          "a non-bag is found");
+    }
+
+    TypeNode elementType = bagType.getBagElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " *) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    if (!(argTypes.size() == 1 && argTypes[0] == elementType))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " *). "
+         << "Found a function of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  TypeNode rangeType = n[0].getType().getRangeType();
+  TypeNode retType = nodeManager->mkBagType(rangeType);
+  return retType;
+}
+
 Cardinality BagsProperties::computeCardinality(TypeNode type)
 {
   return Cardinality::INTEGERS;
index 4874233094f38eb4eccea8dfab44c948d2cb8ae5..53a63a6876522fa5007aa6b2c6d5adfa0c2920fb 100644 (file)
@@ -123,6 +123,15 @@ struct ToSetTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct ToSetTypeRule */
 
+/**
+ * Type rule for (bag.map f B) to make sure f is a unary function of type
+ * (-> T1 T2) where B is a bag of type (Bag T1)
+ */
+struct BagMapTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagMapTypeRule */
+
 struct BagsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
diff --git a/test/regress/regress1/bags/map.smt2 b/test/regress/regress1/bags/map.smt2
new file mode 100644 (file)
index 0000000..54d6714
--- /dev/null
@@ -0,0 +1,12 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(declare-fun f (Int) Int)
+(assert (= A (union_max (bag x 1) (bag y 2))))
+(assert (= A (union_max (bag x 1) (bag y 2))))
+(assert (= B (bag.map f A)))
+(assert (distinct (f x) (f y) x y))
+(check-sat)
index f70ff0c5dcb36c112d4bcebd1b5fbe3e47b6971c..e63fb3b20b3a272d36e86de07fb4d65b4619ce2c 100644 (file)
@@ -694,5 +694,52 @@ TEST_F(TestTheoryWhiteBagsRewriter, to_set)
   ASSERT_TRUE(response.d_node == singleton
               && response.d_status == REWRITE_AGAIN_FULL);
 }
+
+TEST_F(TestTheoryWhiteBagsRewriter, map)
+{
+  Node emptybagString =
+      d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType()));
+
+  Node one = d_nodeManager->mkConst(Rational(1));
+  Node x = d_nodeManager->mkBoundVar("x", d_nodeManager->integerType());
+  std::vector<Node> args;
+  args.push_back(x);
+  Node bound = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, args);
+  Node lambda = d_nodeManager->mkNode(LAMBDA, bound, one);
+
+  // (bag.map (lambda ((x U))  t) emptybag) = emptybag
+  Node n1 = d_nodeManager->mkNode(BAG_MAP, lambda, emptybagString);
+  RewriteResponse response1 = d_rewriter->postRewrite(n1);
+  TypeNode type = d_nodeManager->mkBagType(d_nodeManager->integerType());
+  Node emptybagInteger = d_nodeManager->mkConst(EmptyBag(type));
+  ASSERT_TRUE(response1.d_node == emptybagInteger
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+  std::vector<Node> elements = getNStrings(2);
+  Node a = d_nodeManager->mkConst(String("a"));
+  Node b = d_nodeManager->mkConst(String("b"));
+  Node A = d_nodeManager->mkBag(d_nodeManager->stringType(),
+                                a,
+                                d_nodeManager->mkConst(Rational(3)));
+  Node B = d_nodeManager->mkBag(d_nodeManager->stringType(),
+                                b,
+                                d_nodeManager->mkConst(Rational(4)));
+  Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B);
+
+  ASSERT_TRUE(unionDisjointAB.isConst());
+  // - (bag.map (lambda ((x Int)) 1) (union_disjoint (bag "a" 3) (bag "b" 4))) =
+  //        (bag 1 7))
+  Node n2 = d_nodeManager->mkNode(BAG_MAP, lambda, unionDisjointAB);
+
+  std::cout << n2 << std::endl;
+
+  Node rewritten = Rewriter:: rewrite(n2);
+  std::cout << rewritten << std::endl;
+
+  Node bag = d_nodeManager->mkBag(d_nodeManager->integerType(),
+                                  one,               d_nodeManager->mkConst(Rational(7)));
+  ASSERT_TRUE(rewritten == bag);
+}
+
 }  // namespace test
 }  // namespace cvc5
index 8013d06ea842e1a6b745fc79253db17681ad648e..eace59c96cf30875cbe36082bd763e719a782c03 100644 (file)
@@ -111,5 +111,40 @@ TEST_F(TestTheoryWhiteBagsTypeRule, to_set_operator)
   ASSERT_NO_THROW(d_nodeManager->mkNode(BAG_TO_SET, bag));
   ASSERT_TRUE(d_nodeManager->mkNode(BAG_TO_SET, bag).getType().isSet());
 }
+
+TEST_F(TestTheoryWhiteBagsTypeRule, map_operator)
+{
+  std::vector<Node> elements = getNStrings(1);
+  Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(),
+                                  elements[0],
+                                  d_nodeManager->mkConst(Rational(10)));
+  Node set =
+      d_nodeManager->mkSingleton(d_nodeManager->stringType(), elements[0]);
+
+  Node x1 = d_nodeManager->mkBoundVar("x", d_nodeManager->stringType());
+  Node length = d_nodeManager->mkNode(STRING_LENGTH, x1);
+  std::vector<Node> args1;
+  args1.push_back(x1);
+  Node bound1 = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, args1);
+  Node lambda1 = d_nodeManager->mkNode(LAMBDA, bound1, length);
+
+  ASSERT_NO_THROW(d_nodeManager->mkNode(BAG_MAP, lambda1, bag));
+  Node mappedBag = d_nodeManager->mkNode(BAG_MAP, lambda1, bag);
+  ASSERT_TRUE(mappedBag.getType().isBag());
+  ASSERT_EQ(d_nodeManager->integerType(),
+            mappedBag.getType().getBagElementType());
+
+  Node one = d_nodeManager->mkConst(Rational(1));
+  Node x2 = d_nodeManager->mkBoundVar("x", d_nodeManager->integerType());
+  std::vector<Node> args2;
+  args2.push_back(x2);
+  Node bound2 = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, args2);
+  Node lambda2 = d_nodeManager->mkNode(LAMBDA, bound2, one);
+  ASSERT_THROW(d_nodeManager->mkNode(BAG_MAP, lambda2, bag).getType(true),
+               TypeCheckingExceptionPrivate);
+  ASSERT_THROW(d_nodeManager->mkNode(BAG_MAP, lambda2, set).getType(true),
+               TypeCheckingExceptionPrivate);
+}
+
 }  // namespace test
 }  // namespace cvc5