From 70997d0e3ebf2027279373d9594c66119f3fa656 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Wed, 1 Dec 2021 20:12:30 -0600 Subject: [PATCH] add bag.fold operator (#7718) --- src/CMakeLists.txt | 2 + src/api/cpp/cvc5.cpp | 2 + src/api/cpp/cvc5_kind.h | 16 +++ src/expr/skolem_manager.cpp | 4 + src/expr/skolem_manager.h | 4 + src/parser/smt2/smt2.cpp | 1 + src/printer/smt2/smt2_printer.cpp | 1 + src/theory/bags/bag_reduction.cpp | 119 ++++++++++++++++++ src/theory/bags/bag_reduction.h | 77 ++++++++++++ src/theory/bags/bags_rewriter.cpp | 40 ++++++ src/theory/bags/bags_rewriter.h | 10 ++ src/theory/bags/kinds | 9 ++ src/theory/bags/normal_form.cpp | 38 +++++- src/theory/bags/normal_form.h | 6 + src/theory/bags/rewrites.cpp | 3 + src/theory/bags/rewrites.h | 3 + src/theory/bags/theory_bags.cpp | 20 ++- src/theory/bags/theory_bags.h | 4 + src/theory/bags/theory_bags_type_rules.cpp | 51 ++++++++ src/theory/bags/theory_bags_type_rules.h | 9 ++ src/theory/inference_id.cpp | 1 + src/theory/inference_id.h | 1 + test/regress/CMakeLists.txt | 3 + test/regress/regress1/bags/fold1.smt2 | 10 ++ test/regress/regress1/bags/fold2.smt2 | 15 +++ .../theory/theory_bags_rewriter_white.cpp | 59 ++++++++- 26 files changed, 502 insertions(+), 6 deletions(-) create mode 100644 src/theory/bags/bag_reduction.cpp create mode 100644 src/theory/bags/bag_reduction.h create mode 100644 test/regress/regress1/bags/fold1.smt2 create mode 100644 test/regress/regress1/bags/fold2.smt2 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 025f499e6..96de9afeb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -535,6 +535,8 @@ libcvc5_add_sources( theory/bags/bags_rewriter.h theory/bags/bag_solver.cpp theory/bags/bag_solver.h + theory/bags/bag_reduction.cpp + theory/bags/bag_reduction.h theory/bags/bags_statistics.cpp theory/bags/bags_statistics.h theory/bags/infer_info.cpp diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 6129ff891..c62dde511 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -313,6 +313,7 @@ const static std::unordered_map s_kinds{ {BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET}, {BAG_TO_SET, cvc5::Kind::BAG_TO_SET}, {BAG_MAP, cvc5::Kind::BAG_MAP}, + {BAG_FOLD, cvc5::Kind::BAG_FOLD}, /* Strings ------------------------------------------------------------- */ {STRING_CONCAT, cvc5::Kind::STRING_CONCAT}, {STRING_IN_REGEXP, cvc5::Kind::STRING_IN_REGEXP}, @@ -624,6 +625,7 @@ const static std::unordered_map {cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET}, {cvc5::Kind::BAG_TO_SET, BAG_TO_SET}, {cvc5::Kind::BAG_MAP, BAG_MAP}, + {cvc5::Kind::BAG_FOLD, BAG_FOLD}, /* Strings --------------------------------------------------------- */ {cvc5::Kind::STRING_CONCAT, STRING_CONCAT}, {cvc5::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP}, diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index e6a03cbe4..73843f9b5 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -2539,6 +2539,22 @@ enum Kind : int32_t * - `Solver::mkTerm(Kind kind, const std::vector& children) const` */ BAG_MAP, + /** + * bag.fold operator combines elements of a bag into a single value. + * (bag.fold f t B) folds the elements of bag B starting with term t and using + * the combining function f. + * + * Parameters: + * - 1: a binary operation of type (-> T1 T2 T2) + * - 2: an initial value of type T2 + * - 2: a bag of type (Bag T1) + * + * Create with: + * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2, + * const Term& child3) const` + * - `Solver::mkTerm(Kind kind, const std::vector& children) const` + */ + BAG_FOLD, /* Strings --------------------------------------------------------------- */ diff --git a/src/expr/skolem_manager.cpp b/src/expr/skolem_manager.cpp index db976559f..476517820 100644 --- a/src/expr/skolem_manager.cpp +++ b/src/expr/skolem_manager.cpp @@ -68,6 +68,10 @@ const char* toString(SkolemFunId id) case SkolemFunId::SK_FIRST_MATCH_POST: return "SK_FIRST_MATCH_POST"; case SkolemFunId::RE_UNFOLD_POS_COMPONENT: return "RE_UNFOLD_POS_COMPONENT"; case SkolemFunId::BAGS_CHOOSE: return "BAGS_CHOOSE"; + case SkolemFunId::BAGS_FOLD_CARD: return "BAGS_FOLD_CARD"; + case SkolemFunId::BAGS_FOLD_COMBINE: return "BAGS_FOLD_COMBINE"; + case SkolemFunId::BAGS_FOLD_ELEMENTS: return "BAGS_FOLD_ELEMENTS"; + case SkolemFunId::BAGS_FOLD_UNION_DISJOINT: return "BAGS_FOLD_UNION_DISJOINT"; case SkolemFunId::BAGS_MAP_PREIMAGE: return "BAGS_MAP_PREIMAGE"; case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM"; case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED"; diff --git a/src/expr/skolem_manager.h b/src/expr/skolem_manager.h index a18de8a2e..780413d17 100644 --- a/src/expr/skolem_manager.h +++ b/src/expr/skolem_manager.h @@ -112,6 +112,10 @@ enum class SkolemFunId * i = 0, ..., n. */ RE_UNFOLD_POS_COMPONENT, + BAGS_FOLD_CARD, + BAGS_FOLD_COMBINE, + BAGS_FOLD_ELEMENTS, + BAGS_FOLD_UNION_DISJOINT, /** An interpreted function for bag.choose operator: * (bag.choose A) is expanded as * (witness ((x elementType)) diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index ad380a31c..4e1a8aae8 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -629,6 +629,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(api::BAG_FROM_SET, "bag.from_set"); addOperator(api::BAG_TO_SET, "bag.to_set"); addOperator(api::BAG_MAP, "bag.map"); + addOperator(api::BAG_FOLD, "bag.fold"); } if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) { defineType("String", d_solver->getStringSort(), true, true); diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 13477b792..875ca7dc2 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1098,6 +1098,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_FROM_SET: return "bag.from_set"; case kind::BAG_TO_SET: return "bag.to_set"; case kind::BAG_MAP: return "bag.map"; + case kind::BAG_FOLD: return "bag.fold"; // fp theory case kind::FLOATINGPOINT_FP: return "fp"; diff --git a/src/theory/bags/bag_reduction.cpp b/src/theory/bags/bag_reduction.cpp new file mode 100644 index 000000000..9203a1c45 --- /dev/null +++ b/src/theory/bags/bag_reduction.cpp @@ -0,0 +1,119 @@ +/****************************************************************************** + * Top contributors (to current version): + * Mudathir Mohamed + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * bag reduction. + */ + +#include "theory/bags/bag_reduction.h" + +#include "expr/bound_var_manager.h" +#include "expr/emptybag.h" +#include "expr/skolem_manager.h" +#include "theory/quantifiers/fmf/bounded_integers.h" +#include "util/rational.h" + +using namespace cvc5; +using namespace cvc5::kind; + +namespace cvc5 { +namespace theory { +namespace bags { + +BagReduction::BagReduction(Env& env) : EnvObj(env) {} + +BagReduction::~BagReduction() {} + +/** + * A bound variable corresponding to the universally quantified integer + * variable used to range over the distinct elements in a bag, used + * for axiomatizing the behavior of some term. + */ +struct IndexVarAttributeId +{ +}; +typedef expr::Attribute IndexVarAttribute; + +Node BagReduction::reduceFoldOperator(Node node, std::vector& asserts) +{ + Assert(node.getKind() == BAG_FOLD); + if (d_env.getLogicInfo().isHigherOrder()) + { + NodeManager* nm = NodeManager::currentNM(); + SkolemManager* sm = nm->getSkolemManager(); + Node f = node[0]; + Node t = node[1]; + Node A = node[2]; + Node zero = nm->mkConst(CONST_RATIONAL, Rational(0)); + Node one = nm->mkConst(CONST_RATIONAL, Rational(1)); + // types + TypeNode bagType = A.getType(); + TypeNode elementType = A.getType().getBagElementType(); + TypeNode integerType = nm->integerType(); + TypeNode ufType = nm->mkFunctionType(integerType, elementType); + TypeNode resultType = t.getType(); + TypeNode combineType = nm->mkFunctionType(integerType, resultType); + TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType); + // skolem functions + Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A); + Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A); + Node unionDisjoint = sm->mkSkolemFunction( + SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A); + Node combine = sm->mkSkolemFunction( + SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A}); + + BoundVarManager* bvm = nm->getBoundVarManager(); + Node i = bvm->mkBoundVar(node, "i", nm->integerType()); + Node iList = nm->mkNode(BOUND_VAR_LIST, i); + Node iMinusOne = nm->mkNode(MINUS, i, one); + Node uf_i = nm->mkNode(APPLY_UF, uf, i); + Node combine_0 = nm->mkNode(APPLY_UF, combine, zero); + Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne); + Node combine_i = nm->mkNode(APPLY_UF, combine, i); + Node combine_n = nm->mkNode(APPLY_UF, combine, n); + Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero); + Node unionDisjoint_iMinusOne = + nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne); + Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i); + Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n); + Node combine_0_equal = combine_0.eqNode(t); + Node combine_i_equal = + combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne)); + Node unionDisjoint_0_equal = + unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType))); + Node singleton = nm->mkBag(elementType, uf_i, one); + + Node unionDisjoint_i_equal = unionDisjoint_i.eqNode( + nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne)); + Node interval_i = + nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n)); + + Node body_i = + nm->mkNode(IMPLIES, + interval_i, + nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal)); + Node forAll_i = + quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i); + Node nonNegative = nm->mkNode(GEQ, n, zero); + Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n); + asserts.push_back(forAll_i); + asserts.push_back(combine_0_equal); + asserts.push_back(unionDisjoint_0_equal); + asserts.push_back(unionDisjoint_n_equal); + asserts.push_back(nonNegative); + return combine_n; + } + return Node::null(); +} + +} // namespace bags +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/bags/bag_reduction.h b/src/theory/bags/bag_reduction.h new file mode 100644 index 000000000..11f091f94 --- /dev/null +++ b/src/theory/bags/bag_reduction.h @@ -0,0 +1,77 @@ +/****************************************************************************** + * Top contributors (to current version): + * Mudathir Mohamed + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * bag reduction. + */ + +#ifndef CVC5__BAG_REDUCTION_H +#define CVC5__BAG_REDUCTION_H + +#include + +#include "cvc5_private.h" +#include "smt/env_obj.h" +#include "theory/bags/inference_manager.h" + +namespace cvc5 { +namespace theory { +namespace bags { + +/** + * class for bag reductions + */ +class BagReduction : EnvObj +{ + public: + BagReduction(Env& env); + ~BagReduction(); + + /** + * @param node a term of the form (bag.fold f t A) where + * f: (-> T1 T2 T2) is a binary operation + * t: T2 is the initial value + * A: (Bag T1) is a bag + * @param asserts a list of assertions generated by this reduction + * @return the reduction term (combine n) such that + * (and + * (forall ((i Int)) + * (let ((iMinusOne (- i 1))) + * (let ((uf_i (uf i))) + * (=> + * (and (>= i 1) (<= i n)) + * (and + * (= (combine i) (f uf_i (combine iMinusOne))) + * (= + * (unionDisjoint i) + * (bag.union_disjoint + * (bag uf_i 1) + * (unionDisjoint iMinusOne)))))))) + * (= (combine 0) t) + * (= (unionDisjoint 0) (as bag.empty (Bag T1))) + * (= A (unionDisjoint n)) + * (>= n 0)) + * where + * n: Int is the cardinality of bag A + * uf:Int -> T1 is an uninterpreted function that represents elements of A + * combine: Int -> T2 is an uninterpreted function + * unionDisjoint: Int -> (Bag T1) is an uninterpreted function + */ + Node reduceFoldOperator(Node node, std::vector& asserts); + + private: +}; + +} // namespace bags +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__BAG_REDUCTION_H */ diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index b8f3b80c9..766731806 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -90,6 +90,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) case BAG_FROM_SET: response = rewriteFromSet(n); break; case BAG_TO_SET: response = rewriteToSet(n); break; case BAG_MAP: response = postRewriteMap(n); break; + case BAG_FOLD: response = postRewriteFold(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); break; } } @@ -560,6 +561,45 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const default: return BagsRewriteResponse(n, Rewrite::NONE); } } + +BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const +{ + Assert(n.getKind() == kind::BAG_FOLD); + Node f = n[0]; + Node t = n[1]; + Node bag = n[2]; + if (bag.isConst()) + { + Node value = NormalForm::evaluateBagFold(n); + return BagsRewriteResponse(value, Rewrite::FOLD_CONST); + } + Kind k = bag.getKind(); + switch (k) + { + case BAG_MAKE: + { + if (bag[1].isConst() && bag[1].getConst() > Rational(0)) + { + // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0 + Node value = NormalForm::evaluateBagFold(n); + return BagsRewriteResponse(value, Rewrite::FOLD_BAG); + } + break; + } + case BAG_UNION_DISJOINT: + { + // (bag.fold f t (bag.union_disjoint A B)) = + // (bag.fold f (bag.fold f t A) B) where A < B to break symmetry + Node A = bag[0] < bag[1] ? bag[0] : bag[1]; + Node B = bag[0] < bag[1] ? bag[1] : bag[0]; + Node foldA = d_nm->mkNode(BAG_FOLD, f, t, A); + Node fold = d_nm->mkNode(BAG_FOLD, f, foldA, B); + return BagsRewriteResponse(fold, Rewrite::FOLD_UNION_DISJOINT); + } + default: return BagsRewriteResponse(n, Rewrite::NONE); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index a938b3bd4..d666982a7 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -222,6 +222,16 @@ class BagsRewriter : public TheoryRewriter */ BagsRewriteResponse postRewriteMap(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.fold f t (as bag.empty (Bag T1))) = t + * - (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, where n > 0 + * - (bag.fold f t (bag.union_disjoint A B)) = + * (bag.fold f (bag.fold f t A) B) where A < B to break symmetry + * where f: T1 -> T2 -> T2 + */ + BagsRewriteResponse postRewriteFold(const TNode& n) const; + private: /** Reference to the rewriter statistics. */ NodeManager* d_nm; diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index a5c6e75bf..5e4119fa1 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -76,6 +76,14 @@ operator BAG_CHOOSE 1 "return an element in the bag given as a parameter # of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2). operator BAG_MAP 2 "bag map function" +# bag.fold operator combines elements of a bag into a single value. +# (bag.fold f t B) folds the elements of bag B starting with term t and using +# the combining function f. +# f: a binary operation of type (-> T1 T2 T2) +# t: an initial value of type T2 +# B: a bag of type (Bag T1) +operator BAG_FOLD 3 "bag fold operator" + typerule BAG_UNION_MAX ::cvc5::theory::bags::BinaryOperatorTypeRule typerule BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule typerule BAG_INTER_MIN ::cvc5::theory::bags::BinaryOperatorTypeRule @@ -93,6 +101,7 @@ typerule BAG_IS_SINGLETON ::cvc5::theory::bags::IsSingletonTypeRule typerule BAG_FROM_SET ::cvc5::theory::bags::FromSetTypeRule typerule BAG_TO_SET ::cvc5::theory::bags::ToSetTypeRule typerule BAG_MAP ::cvc5::theory::bags::BagMapTypeRule +typerule BAG_FOLD ::cvc5::theory::bags::BagFoldTypeRule construle BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule construle BAG_MAKE ::cvc5::theory::bags::BagMakeTypeRule diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index 12bf513b5..9a510c6f5 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -110,6 +110,7 @@ Node NormalForm::evaluate(TNode n) case BAG_FROM_SET: return evaluateFromSet(n); case BAG_TO_SET: return evaluateToSet(n); case BAG_MAP: return evaluateBagMap(n); + case BAG_FOLD: return evaluateBagFold(n); default: break; } Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n @@ -169,8 +170,6 @@ Node NormalForm::evaluateBinaryOperation(const TNode& n, std::map NormalForm::getBagElements(TNode n) { - Assert(n.isConst()) << "node " << n << " is not in a normal form" - << std::endl; std::map elements; if (n.getKind() == BAG_EMPTY) { @@ -692,6 +691,41 @@ Node NormalForm::evaluateBagMap(TNode n) return ret; } +Node NormalForm::evaluateBagFold(TNode n) +{ + Assert(n.getKind() == BAG_FOLD); + + // Examples + // -------- + // minimum string + // - (bag.fold + // ((lambda ((x String) (y String)) (ite (str.< x y) x y)) + // "" + // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) + // = "a" + + Node f = n[0]; // combining function + Node ret = n[1]; // initial value + Node A = n[2]; // bag + std::map elements = NormalForm::getBagElements(A); + + std::map::iterator it = elements.begin(); + NodeManager* nm = NodeManager::currentNM(); + while (it != elements.end()) + { + // apply the combination function n times, where n is the multiplicity + Rational count = it->second; + Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl; + while (!count.isZero()) + { + ret = nm->mkNode(APPLY_UF, f, it->first, ret); + count = count - 1; + } + ++it; + } + return ret; +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h index 8ceee2881..5275678ff 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/normal_form.h @@ -75,6 +75,12 @@ class NormalForm static Node constructBagFromElements(TypeNode t, const std::map& elements); + /** + * @param n has the form (bag.fold f t A) where A is a constant bag + * @return a single value which is the result of the fold + */ + static Node evaluateBagFold(TNode n); + private: /** * a high order helper function that return a constant bag that is the result diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index 896c4f251..1a8f8f849 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -38,6 +38,9 @@ const char* toString(Rewrite r) case Rewrite::EQ_REFL: return "EQ_REFL"; case Rewrite::EQ_SYM: return "EQ_SYM"; case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON"; + case Rewrite::FOLD_BAG: return "FOLD_BAG"; + case Rewrite::FOLD_CONST: return "FOLD_CONST"; + case Rewrite::FOLD_UNION_DISJOINT: return "FOLD_UNION_DISJOINT"; case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES"; case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT"; case Rewrite::INTERSECTION_EMPTY_RIGHT: return "INTERSECTION_EMPTY_RIGHT"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index c5050ea72..0b7188599 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -42,6 +42,9 @@ enum class Rewrite : uint32_t EQ_REFL, EQ_SYM, FROM_SINGLETON, + FOLD_BAG, + FOLD_CONST, + FOLD_UNION_DISJOINT, IDENTICAL_NODES, INTERSECTION_EMPTY_LEFT, INTERSECTION_EMPTY_RIGHT, diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 4dffbdb00..68bdb7b1b 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -20,6 +20,7 @@ #include "proof/proof_checker.h" #include "smt/logic_exception.h" #include "theory/bags/normal_form.h" +#include "theory/quantifiers/fmf/bounded_integers.h" #include "theory/rewriter.h" #include "theory/theory_model.h" #include "util/rational.h" @@ -39,7 +40,8 @@ TheoryBags::TheoryBags(Env& env, OutputChannel& out, Valuation valuation) d_statistics(), d_rewriter(&d_statistics.d_rewrites), d_termReg(env, d_state, d_im), - d_solver(env, d_state, d_im, d_termReg) + d_solver(env, d_state, d_im, d_termReg), + d_bagReduction(env) { // use the official theory state and inference manager objects d_theoryState = &d_state; @@ -87,6 +89,18 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector& lems) { case kind::BAG_CHOOSE: return expandChooseOperator(atom, lems); case kind::BAG_CARD: return expandCardOperator(atom, lems); + case kind::BAG_FOLD: + { + std::vector asserts; + Node ret = d_bagReduction.reduceFoldOperator(atom, asserts); + NodeManager* nm = NodeManager::currentNM(); + Node andNode = nm->mkNode(AND, asserts); + d_im.lemma(andNode, InferenceId::BAGS_FOLD); + Trace("bags::ppr") << "reduce(" << atom << ") = " << ret + << " such that:" << std::endl + << asserts << std::endl; + return TrustNode::mkTrustRewrite(atom, ret, nullptr); + } default: return TrustNode::null(); } } @@ -131,9 +145,9 @@ TrustNode TheoryBags::expandChooseOperator(const Node& node, return TrustNode::mkTrustRewrite(node, ret, nullptr); } -TrustNode TheoryBags::expandCardOperator(TNode n, - std::vector& vector) +TrustNode TheoryBags::expandCardOperator(TNode n, std::vector&) { + Assert(n.getKind() == BAG_CARD); if (d_env.getLogicInfo().isHigherOrder()) { // (bag.card A) = (bag.count 1 (bag.map (lambda ((x E)) 1) A)), diff --git a/src/theory/bags/theory_bags.h b/src/theory/bags/theory_bags.h index fd28482d4..1a8af780e 100644 --- a/src/theory/bags/theory_bags.h +++ b/src/theory/bags/theory_bags.h @@ -18,6 +18,7 @@ #ifndef CVC5__THEORY__BAGS__THEORY_BAGS_H #define CVC5__THEORY__BAGS__THEORY_BAGS_H +#include "theory/bags/bag_reduction.h" #include "theory/bags/bag_solver.h" #include "theory/bags/bags_rewriter.h" #include "theory/bags/bags_statistics.h" @@ -112,6 +113,9 @@ class TheoryBags : public Theory /** the main solver for bags */ BagSolver d_solver; + /** bag reduction */ + BagReduction d_bagReduction; + void eqNotifyNewClass(TNode n); void eqNotifyMerge(TNode n1, TNode n2); void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index 2623f3ed7..fe81fadf5 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -327,6 +327,57 @@ TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager, return retType; } +TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::BAG_FOLD); + TypeNode functionType = n[0].getType(check); + TypeNode initialValueType = n[1].getType(check); + TypeNode bagType = n[2].getType(check); + if (check) + { + if (!bagType.isBag()) + { + throw TypeCheckingExceptionPrivate( + n, + "bag.fold operator expects a bag in the third argument, " + "a non-bag is found"); + } + + TypeNode elementType = bagType.getBagElementType(); + + if (!(functionType.isFunction())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " T2 T2) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes = functionType.getArgTypes(); + TypeNode rangeType = functionType.getRangeType(); + if (!(argTypes.size() == 2 && argTypes[0] == elementType + && argTypes[1] == rangeType)) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " T2 T2). " + << "Found a function of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + if (rangeType != initialValueType) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects an initial value of type " + << rangeType << ". Found a term of type '" << initialValueType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + TypeNode retType = n[0].getType().getRangeType(); + return retType; +} + Cardinality BagsProperties::computeCardinality(TypeNode type) { return Cardinality::INTEGERS; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index d7b8b2737..fa2f78313 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -132,6 +132,15 @@ struct BagMapTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct BagMapTypeRule */ +/** + * Type rule for (bag.fold f t A) to make sure f is a binary operation of type + * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1) + */ +struct BagFoldTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct BagFoldTypeRule */ + struct BagsProperties { static Cardinality computeCardinality(TypeNode type); diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index 82ae674e2..56d2f0500 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -118,6 +118,7 @@ const char* toString(InferenceId i) case InferenceId::BAGS_DIFFERENCE_REMOVE: return "BAGS_DIFFERENCE_REMOVE"; case InferenceId::BAGS_DUPLICATE_REMOVAL: return "BAGS_DUPLICATE_REMOVAL"; case InferenceId::BAGS_MAP: return "BAGS_MAP"; + case InferenceId::BAGS_FOLD: return "BAGS_FOLD"; case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT"; case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA: diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index ad879d7ab..d98d3ff25 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -180,6 +180,7 @@ enum class InferenceId BAGS_DIFFERENCE_REMOVE, BAGS_DUPLICATE_REMOVAL, BAGS_MAP, + BAGS_FOLD, // ---------------------------------- end bags theory // ---------------------------------- bitvector theory diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index cf114711a..4169036ba 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1606,6 +1606,7 @@ set(regress_1_tests regress1/bags/duplicate_removal1.smt2 regress1/bags/duplicate_removal2.smt2 regress1/bags/emptybag1.smt2 + regress1/bags/fold1.smt2 regress1/bags/fuzzy1.smt2 regress1/bags/fuzzy2.smt2 regress1/bags/fuzzy3.smt2 @@ -2820,6 +2821,8 @@ set(regression_disabled_tests regress0/tptp/SYN075+1.p regress0/uf/iso_icl_repgen004.smtv1.smt2 ### + # takes around 30 sec + regress1/bags/fold2.smt2 regress1/bug472.smt2 regress1/datatypes/non-simple-rec-set.smt2 # results in an assertion failure (see issue #1650). diff --git a/test/regress/regress1/bags/fold1.smt2 b/test/regress/regress1/bags/fold1.smt2 new file mode 100644 index 000000000..73caedae5 --- /dev/null +++ b/test/regress/regress1/bags/fold1.smt2 @@ -0,0 +1,10 @@ +(set-logic HO_ALL) +(set-info :status sat) +(set-option :fmf-bound true) +(set-option :uf-lazy-ll true) +(define-fun plus ((x Int) (y Int)) Int (+ x y)) +(declare-fun A () (Bag Int)) +(declare-fun sum () Int) +(assert (= sum (bag.fold plus 1 A))) +(assert (= sum 10)) +(check-sat) diff --git a/test/regress/regress1/bags/fold2.smt2 b/test/regress/regress1/bags/fold2.smt2 new file mode 100644 index 000000000..9863a11c6 --- /dev/null +++ b/test/regress/regress1/bags/fold2.smt2 @@ -0,0 +1,15 @@ +(set-logic HO_ALL) +(set-info :status sat) +(set-option :fmf-bound true) +(set-option :uf-lazy-ll true) +(set-option :strings-exp true) +(define-fun min ((x String) (y String)) String (ite (str.< x y) x y)) +(declare-fun A () (Bag String)) +(declare-fun x () String) +(declare-fun minimum () String) +(assert (= minimum (bag.fold min "zzz" A))) +(assert (str.< "aaa" minimum )) +(assert (str.< minimum "zzz")) +(assert (distinct x minimum)) +(assert (= (bag.count x A) 2)) +(check-sat) diff --git a/test/unit/theory/theory_bags_rewriter_white.cpp b/test/unit/theory/theory_bags_rewriter_white.cpp index ee1e89448..ff98c308a 100644 --- a/test/unit/theory/theory_bags_rewriter_white.cpp +++ b/test/unit/theory/theory_bags_rewriter_white.cpp @@ -750,7 +750,7 @@ TEST_F(TestTheoryWhiteBagsRewriter, map) Node empty = d_nodeManager->mkConst(String("")); Node xString = d_nodeManager->mkBoundVar("x", d_nodeManager->stringType()); - Node bound = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, xString); + Node bound = d_nodeManager->mkNode(BOUND_VAR_LIST, xString); Node lambda = d_nodeManager->mkNode(LAMBDA, bound, empty); // (bag.map (lambda ((x U)) t) (as bag.empty (Bag String)) = @@ -800,5 +800,62 @@ TEST_F(TestTheoryWhiteBagsRewriter, map) ASSERT_TRUE(rewritten3 == unionDisjointMapK1K2); } +TEST_F(TestTheoryWhiteBagsRewriter, fold) +{ + TypeNode bagIntegerType = + d_nodeManager->mkBagType(d_nodeManager->integerType()); + Node emptybag = d_nodeManager->mkConst(EmptyBag(bagIntegerType)); + Node zero = d_nodeManager->mkConst(CONST_RATIONAL, Rational(0)); + Node one = d_nodeManager->mkConst(CONST_RATIONAL, Rational(1)); + Node ten = d_nodeManager->mkConst(CONST_RATIONAL, Rational(10)); + Node n = d_nodeManager->mkConst(CONST_RATIONAL, Rational(2)); + Node x = d_nodeManager->mkBoundVar("x", d_nodeManager->integerType()); + Node y = d_nodeManager->mkBoundVar("y", d_nodeManager->integerType()); + Node xy = d_nodeManager->mkNode(BOUND_VAR_LIST, x, y); + Node sum = d_nodeManager->mkNode(PLUS, x, y); + + // f(x,y) = 0 for all x, y + Node f = d_nodeManager->mkNode(LAMBDA, xy, zero); + Node node1 = d_nodeManager->mkNode(BAG_FOLD, f, one, emptybag); + RewriteResponse response1 = d_rewriter->postRewrite(node1); + ASSERT_TRUE(response1.d_node == one + && response1.d_status == REWRITE_AGAIN_FULL); + + // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, where n > 0 + f = d_nodeManager->mkNode(LAMBDA, xy, sum); + Node xSkolem = d_nodeManager->getSkolemManager()->mkDummySkolem( + "x", d_nodeManager->integerType()); + Node bag = d_nodeManager->mkBag(d_nodeManager->integerType(), xSkolem, n); + Node node2 = d_nodeManager->mkNode(BAG_FOLD, f, one, bag); + Node apply_f_once = d_nodeManager->mkNode(APPLY_UF, f, xSkolem, one); + Node apply_f_twice = + d_nodeManager->mkNode(APPLY_UF, f, xSkolem, apply_f_once); + RewriteResponse response2 = d_rewriter->postRewrite(node2); + ASSERT_TRUE(response2.d_node == apply_f_twice + && response2.d_status == REWRITE_AGAIN_FULL); + + // (bag.fold (lambda ((x Int)(y Int)) (+ x y)) 1 (bag 10 2)) = 21 + bag = d_nodeManager->mkBag(d_nodeManager->integerType(), ten, n); + Node node3 = d_nodeManager->mkNode(BAG_FOLD, f, one, bag); + Node result3 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(21)); + Node response3 = Rewriter::rewrite(node3); + ASSERT_TRUE(response3 == result3); + + // (bag.fold f t (bag.union_disjoint A B)) = + // (bag.fold f (bag.fold f t A) B) where A < B to break symmetry + + Node A = + d_nodeManager->getSkolemManager()->mkDummySkolem("A", bagIntegerType); + Node B = + d_nodeManager->getSkolemManager()->mkDummySkolem("B", bagIntegerType); + Node disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B); + Node node4 = d_nodeManager->mkNode(BAG_FOLD, f, one, disjoint); + Node foldA = d_nodeManager->mkNode(BAG_FOLD, f, one, A); + Node fold = d_nodeManager->mkNode(BAG_FOLD, f, foldA, B); + RewriteResponse response4 = d_rewriter->postRewrite(node4); + ASSERT_TRUE(response4.d_node == fold + && response2.d_status == REWRITE_AGAIN_FULL); +} + } // namespace test } // namespace cvc5 -- 2.30.2