From: mudathirmahgoub Date: Mon, 30 Aug 2021 23:26:43 +0000 (-0500) Subject: Add kind BAG_MAP and its type rule to bags (#6503) X-Git-Tag: cvc5-1.0.0~1318 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=22c83c12097da6105e6f03e0df70385527e651a4;p=cvc5.git Add kind BAG_MAP and its type rule to bags (#6503) This PR adds kind BAG_MAP to bags. --- diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index e245dc415..626edf7bb 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -308,6 +308,7 @@ const static std::unordered_map s_kinds{ {BAG_IS_SINGLETON, cvc5::Kind::BAG_IS_SINGLETON}, {BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET}, {BAG_TO_SET, cvc5::Kind::BAG_TO_SET}, + {BAG_MAP, cvc5::Kind::BAG_MAP}, /* Strings ------------------------------------------------------------- */ {STRING_CONCAT, cvc5::Kind::STRING_CONCAT}, {STRING_IN_REGEXP, cvc5::Kind::STRING_IN_REGEXP}, @@ -617,6 +618,7 @@ const static std::unordered_map {cvc5::Kind::BAG_IS_SINGLETON, BAG_IS_SINGLETON}, {cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET}, {cvc5::Kind::BAG_TO_SET, BAG_TO_SET}, + {cvc5::Kind::BAG_MAP,BAG_MAP}, /* Strings --------------------------------------------------------- */ {cvc5::Kind::STRING_CONCAT, STRING_CONCAT}, {cvc5::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP}, diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index e8b876b55..94a8a6f92 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -2515,6 +2515,21 @@ enum CVC5_EXPORT Kind : int32_t * - `Solver::mkTerm(Kind kind, const Term& child) const` */ BAG_TO_SET, + /** + * bag.map operator applies the first argument, a function of type (-> T1 T2), + * to every element of the second argument, a bag of type (Bag T1), + * and returns a bag of type (Bag T2). + * + * Parameters: + * - 1: a function of type (-> T1 T2) + * - 2: a bag of type (Bag T1) + * + * Create with: + * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2) + * const` + * - `Solver::mkTerm(Kind kind, const std::vector& children) const` + */ + BAG_MAP, /* Strings --------------------------------------------------------------- */ diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 39492a98c..1a0a3d52a 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -635,6 +635,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(api::BAG_IS_SINGLETON, "bag.is_singleton"); addOperator(api::BAG_FROM_SET, "bag.from_set"); addOperator(api::BAG_TO_SET, "bag.to_set"); + addOperator(api::BAG_MAP, "bag.map"); } if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) { defineType("String", d_solver->getStringSort(), true, true); @@ -1103,7 +1104,7 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) if ((*i).getSort().isFunction()) { parseError( - "Cannot apply equalty to functions unless logic is prefixed by " + "Cannot apply equality to functions unless logic is prefixed by " "HO_."); } } diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 523b3efa9..8a23a59ea 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1083,6 +1083,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_IS_SINGLETON: return "bag.is_singleton"; case kind::BAG_FROM_SET: return "bag.from_set"; case kind::BAG_TO_SET: return "bag.to_set"; + case kind::BAG_MAP: return "bag.map"; // fp theory case kind::FLOATINGPOINT_FP: return "fp"; diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index b9f620d51..f2af95087 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -84,6 +84,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break; case BAG_FROM_SET: response = rewriteFromSet(n); break; case BAG_TO_SET: response = rewriteToSet(n); break; + case BAG_MAP: response = postRewriteMap(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); break; } } @@ -505,6 +506,47 @@ BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const return BagsRewriteResponse(n, Rewrite::NONE); } +BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const +{ + Assert(n.getKind() == kind::BAG_MAP); + if (n[1].isConst()) + { + // (bag.map f emptybag) = emptybag + // (bag.map f (bag "a" 3) = (bag (f "a") 3) + std::map elements = NormalForm::getBagElements(n[1]); + std::map mappedElements; + std::map::iterator it = elements.begin(); + while (it != elements.end()) + { + Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], it->first); + mappedElements[mappedElement] = it->second; + ++it; + } + TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType()); + Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements); + return BagsRewriteResponse(ret, Rewrite::MAP_CONST); + } + Kind k = n[1].getKind(); + switch (k) + { + case MK_BAG: + { + Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], n[1][0]); + Node ret = d_nm->mkNode(MK_BAG, mappedElement, n[1][0]); + return BagsRewriteResponse(ret, Rewrite::MAP_MK_BAG); + } + + case UNION_DISJOINT: + { + Node a = d_nm->mkNode(BAG_MAP, n[1][0]); + Node b = d_nm->mkNode(BAG_MAP, n[1][1]); + Node ret = d_nm->mkNode(UNION_DISJOINT, a, b); + return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT); + } + + default: return BagsRewriteResponse(n, Rewrite::NONE); + } +} } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index 83f364f9d..eb5c9f9ab 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -211,6 +211,18 @@ class BagsRewriter : public TheoryRewriter */ BagsRewriteResponse postRewriteEqual(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.map (lambda ((x U)) t) emptybag) = emptybag + * - (bag.map (lambda ((x U)) t) (bag y z)) = (bag (apply (lambda ((x U)) t) y) z) + * - (bag.map (lambda ((x U)) t) (union_disjoint A B)) = + * (union_disjoint + * (bag ((lambda ((x U)) t) "a") 3) + * (bag ((lambda ((x U)) t) "b") 4)) + * + */ + BagsRewriteResponse postRewriteMap(const TNode& n) const; + private: /** Reference to the rewriter statistics. */ NodeManager* d_nm; diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 795410239..55fd28695 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -72,6 +72,10 @@ operator BAG_TO_SET 1 "converts a bag to a set" # If the bag has cardinality > 1, then (choose A) will deterministically return an element in A. operator BAG_CHOOSE 1 "return an element in the bag given as a parameter" +# The bag.map operator applies the first argument, a function of type (-> T1 T2), to every element +# of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2). +operator BAG_MAP 2 "bag map function" + typerule UNION_MAX ::cvc5::theory::bags::BinaryOperatorTypeRule typerule UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule typerule INTERSECTION_MIN ::cvc5::theory::bags::BinaryOperatorTypeRule @@ -88,6 +92,7 @@ typerule BAG_CHOOSE ::cvc5::theory::bags::ChooseTypeRule typerule BAG_IS_SINGLETON ::cvc5::theory::bags::IsSingletonTypeRule typerule BAG_FROM_SET ::cvc5::theory::bags::FromSetTypeRule typerule BAG_TO_SET ::cvc5::theory::bags::ToSetTypeRule +typerule BAG_MAP ::cvc5::theory::bags::BagMapTypeRule construle UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule construle MK_BAG ::cvc5::theory::bags::MkBagTypeRule diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index ec32d0138..58445de59 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -109,6 +109,7 @@ Node NormalForm::evaluate(TNode n) case BAG_IS_SINGLETON: return evaluateIsSingleton(n); case BAG_FROM_SET: return evaluateFromSet(n); case BAG_TO_SET: return evaluateToSet(n); + case BAG_MAP: return evaluateBagMap(n); default: break; } Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n @@ -675,6 +676,35 @@ Node NormalForm::evaluateToSet(TNode n) return set; } + +Node NormalForm::evaluateBagMap(TNode n) +{ + Assert(n.getKind() == BAG_MAP); + + // Examples + // -------- + // - (bag.map ((lambda ((x String)) "z") + // (union_disjoint (bag "a" 2) (bag "b" 3)) = + // (union_disjoint + // (bag ((lambda ((x String)) "z") "a") 2) + // (bag ((lambda ((x String)) "z") "b") 3)) = + // (bag "z" 5) + + std::map elements = NormalForm::getBagElements(n[1]); + std::map mappedElements; + std::map::iterator it = elements.begin(); + NodeManager* nm = NodeManager::currentNM(); + while (it != elements.end()) + { + Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first); + mappedElements[mappedElement] = it->second; + ++it; + } + TypeNode t = nm->mkBagType(n[0].getType().getRangeType()); + Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements); + return ret; +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h index 124ecdf5f..f104e0381 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/normal_form.h @@ -190,6 +190,11 @@ class NormalForm * @return a constant set constructed from the elements in A. */ static Node evaluateToSet(TNode n); + /** + * @param n has the form (bag.map f A) where A is a constant bag + * @return a constant bag constructed from the images of elements in A. + */ + static Node evaluateBagMap(TNode n); }; } // namespace bags } // namespace theory diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index ff77c4187..c8aeec147 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -44,6 +44,9 @@ const char* toString(Rewrite r) case Rewrite::INTERSECTION_SHARED_LEFT: return "INTERSECTION_SHARED_LEFT"; case Rewrite::INTERSECTION_SHARED_RIGHT: return "INTERSECTION_SHARED_RIGHT"; case Rewrite::IS_SINGLETON_MK_BAG: return "IS_SINGLETON_MK_BAG"; + case Rewrite::MAP_CONST: return "MAP_CONST"; + case Rewrite::MAP_MK_BAG: return "MAP_MK_BAG"; + case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT"; case Rewrite::MK_BAG_COUNT_NEGATIVE: return "MK_BAG_COUNT_NEGATIVE"; case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION"; case Rewrite::REMOVE_MIN: return "REMOVE_MIN"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index f5977332a..78eb502c8 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -49,6 +49,9 @@ enum class Rewrite : uint32_t INTERSECTION_SHARED_LEFT, INTERSECTION_SHARED_RIGHT, IS_SINGLETON_MK_BAG, + MAP_CONST, + MAP_MK_BAG, + MAP_UNION_DISJOINT, MK_BAG_COUNT_NEGATIVE, REMOVE_FROM_UNION, REMOVE_MIN, diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index d820ce6e1..7f45b9b1a 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -283,6 +283,48 @@ TypeNode ToSetTypeRule::computeType(NodeManager* nodeManager, return setType; } +TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::BAG_MAP); + TypeNode functionType = n[0].getType(check); + TypeNode bagType = n[1].getType(check); + if (check) + { + if (!bagType.isBag()) + { + throw TypeCheckingExceptionPrivate( + n, + "bag.map operator expects a bag in the second argument, " + "a non-bag is found"); + } + + TypeNode elementType = bagType.getBagElementType(); + + if (!(functionType.isFunction())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " *) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes = functionType.getArgTypes(); + if (!(argTypes.size() == 1 && argTypes[0] == elementType)) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " *). " + << "Found a function of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + TypeNode rangeType = n[0].getType().getRangeType(); + TypeNode retType = nodeManager->mkBagType(rangeType); + return retType; +} + Cardinality BagsProperties::computeCardinality(TypeNode type) { return Cardinality::INTEGERS; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 487423309..53a63a687 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -123,6 +123,15 @@ struct ToSetTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct ToSetTypeRule */ +/** + * Type rule for (bag.map f B) to make sure f is a unary function of type + * (-> T1 T2) where B is a bag of type (Bag T1) + */ +struct BagMapTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct BagMapTypeRule */ + struct BagsProperties { static Cardinality computeCardinality(TypeNode type); diff --git a/test/regress/regress1/bags/map.smt2 b/test/regress/regress1/bags/map.smt2 new file mode 100644 index 000000000..54d671415 --- /dev/null +++ b/test/regress/regress1/bags/map.smt2 @@ -0,0 +1,12 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(declare-fun f (Int) Int) +(assert (= A (union_max (bag x 1) (bag y 2)))) +(assert (= A (union_max (bag x 1) (bag y 2)))) +(assert (= B (bag.map f A))) +(assert (distinct (f x) (f y) x y)) +(check-sat) diff --git a/test/unit/theory/theory_bags_rewriter_white.cpp b/test/unit/theory/theory_bags_rewriter_white.cpp index f70ff0c5d..e63fb3b20 100644 --- a/test/unit/theory/theory_bags_rewriter_white.cpp +++ b/test/unit/theory/theory_bags_rewriter_white.cpp @@ -694,5 +694,52 @@ TEST_F(TestTheoryWhiteBagsRewriter, to_set) ASSERT_TRUE(response.d_node == singleton && response.d_status == REWRITE_AGAIN_FULL); } + +TEST_F(TestTheoryWhiteBagsRewriter, map) +{ + Node emptybagString = + d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType())); + + Node one = d_nodeManager->mkConst(Rational(1)); + Node x = d_nodeManager->mkBoundVar("x", d_nodeManager->integerType()); + std::vector args; + args.push_back(x); + Node bound = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, args); + Node lambda = d_nodeManager->mkNode(LAMBDA, bound, one); + + // (bag.map (lambda ((x U)) t) emptybag) = emptybag + Node n1 = d_nodeManager->mkNode(BAG_MAP, lambda, emptybagString); + RewriteResponse response1 = d_rewriter->postRewrite(n1); + TypeNode type = d_nodeManager->mkBagType(d_nodeManager->integerType()); + Node emptybagInteger = d_nodeManager->mkConst(EmptyBag(type)); + ASSERT_TRUE(response1.d_node == emptybagInteger + && response1.d_status == REWRITE_AGAIN_FULL); + + std::vector elements = getNStrings(2); + Node a = d_nodeManager->mkConst(String("a")); + Node b = d_nodeManager->mkConst(String("b")); + Node A = d_nodeManager->mkBag(d_nodeManager->stringType(), + a, + d_nodeManager->mkConst(Rational(3))); + Node B = d_nodeManager->mkBag(d_nodeManager->stringType(), + b, + d_nodeManager->mkConst(Rational(4))); + Node unionDisjointAB = d_nodeManager->mkNode(UNION_DISJOINT, A, B); + + ASSERT_TRUE(unionDisjointAB.isConst()); + // - (bag.map (lambda ((x Int)) 1) (union_disjoint (bag "a" 3) (bag "b" 4))) = + // (bag 1 7)) + Node n2 = d_nodeManager->mkNode(BAG_MAP, lambda, unionDisjointAB); + + std::cout << n2 << std::endl; + + Node rewritten = Rewriter:: rewrite(n2); + std::cout << rewritten << std::endl; + + Node bag = d_nodeManager->mkBag(d_nodeManager->integerType(), + one, d_nodeManager->mkConst(Rational(7))); + ASSERT_TRUE(rewritten == bag); +} + } // namespace test } // namespace cvc5 diff --git a/test/unit/theory/theory_bags_type_rules_white.cpp b/test/unit/theory/theory_bags_type_rules_white.cpp index 8013d06ea..eace59c96 100644 --- a/test/unit/theory/theory_bags_type_rules_white.cpp +++ b/test/unit/theory/theory_bags_type_rules_white.cpp @@ -111,5 +111,40 @@ TEST_F(TestTheoryWhiteBagsTypeRule, to_set_operator) ASSERT_NO_THROW(d_nodeManager->mkNode(BAG_TO_SET, bag)); ASSERT_TRUE(d_nodeManager->mkNode(BAG_TO_SET, bag).getType().isSet()); } + +TEST_F(TestTheoryWhiteBagsTypeRule, map_operator) +{ + std::vector elements = getNStrings(1); + Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), + elements[0], + d_nodeManager->mkConst(Rational(10))); + Node set = + d_nodeManager->mkSingleton(d_nodeManager->stringType(), elements[0]); + + Node x1 = d_nodeManager->mkBoundVar("x", d_nodeManager->stringType()); + Node length = d_nodeManager->mkNode(STRING_LENGTH, x1); + std::vector args1; + args1.push_back(x1); + Node bound1 = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, args1); + Node lambda1 = d_nodeManager->mkNode(LAMBDA, bound1, length); + + ASSERT_NO_THROW(d_nodeManager->mkNode(BAG_MAP, lambda1, bag)); + Node mappedBag = d_nodeManager->mkNode(BAG_MAP, lambda1, bag); + ASSERT_TRUE(mappedBag.getType().isBag()); + ASSERT_EQ(d_nodeManager->integerType(), + mappedBag.getType().getBagElementType()); + + Node one = d_nodeManager->mkConst(Rational(1)); + Node x2 = d_nodeManager->mkBoundVar("x", d_nodeManager->integerType()); + std::vector args2; + args2.push_back(x2); + Node bound2 = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, args2); + Node lambda2 = d_nodeManager->mkNode(LAMBDA, bound2, one); + ASSERT_THROW(d_nodeManager->mkNode(BAG_MAP, lambda2, bag).getType(true), + TypeCheckingExceptionPrivate); + ASSERT_THROW(d_nodeManager->mkNode(BAG_MAP, lambda2, set).getType(true), + TypeCheckingExceptionPrivate); +} + } // namespace test } // namespace cvc5