From: mudathirmahgoub Date: Thu, 3 Feb 2022 15:55:16 +0000 (-0600) Subject: Add table.product operator (#8020) X-Git-Tag: cvc5-1.0.0~461 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=3eb47718f6e24cc719094732b639e1d8b73012a4;p=cvc5.git Add table.product operator (#8020) --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 226f4632d..87ba2bb94 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -661,6 +661,8 @@ libcvc5_add_sources( theory/datatypes/theory_datatypes_utils.h theory/datatypes/tuple_project_op.cpp theory/datatypes/tuple_project_op.h + theory/datatypes/tuple_utils.cpp + theory/datatypes/tuple_utils.h theory/datatypes/type_enumerator.cpp theory/datatypes/type_enumerator.h theory/decision_manager.cpp @@ -997,6 +999,7 @@ libcvc5_add_sources( theory/sets/inference_manager.cpp theory/sets/inference_manager.h theory/sets/normal_form.h + theory/sets/rels_utils.cpp theory/sets/rels_utils.h theory/sets/singleton_op.cpp theory/sets/singleton_op.h diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index d46b8a971..54174aec4 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -314,6 +314,7 @@ const static std::unordered_map s_kinds{ {BAG_MAP, cvc5::Kind::BAG_MAP}, {BAG_FILTER, cvc5::Kind::BAG_FILTER}, {BAG_FOLD, cvc5::Kind::BAG_FOLD}, + {TABLE_PRODUCT, cvc5::Kind::TABLE_PRODUCT}, /* Strings ------------------------------------------------------------- */ {STRING_CONCAT, cvc5::Kind::STRING_CONCAT}, {STRING_IN_REGEXP, cvc5::Kind::STRING_IN_REGEXP}, @@ -627,6 +628,7 @@ const static std::unordered_map {cvc5::Kind::BAG_MAP, BAG_MAP}, {cvc5::Kind::BAG_FILTER, BAG_FILTER}, {cvc5::Kind::BAG_FOLD, BAG_FOLD}, + {cvc5::Kind::TABLE_PRODUCT, TABLE_PRODUCT}, /* Strings --------------------------------------------------------- */ {cvc5::Kind::STRING_CONCAT, STRING_CONCAT}, {cvc5::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP}, diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index 1609fb221..112b53eb7 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -2572,6 +2572,17 @@ enum Kind : int32_t * - `Solver::mkTerm(Kind kind, const std::vector& children) const` */ BAG_FOLD, + /** + * Table cross product. + * + * Parameters: + * - 1..2: Terms of bag sort + * + * Create with: + * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2) const` + * - `Solver::mkTerm(Kind kind, const std::vector& children) const` + */ + TABLE_PRODUCT, /* Strings --------------------------------------------------------------- */ diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 1c3ea84df..3352bde1b 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -629,6 +629,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(api::BAG_MAP, "bag.map"); addOperator(api::BAG_FILTER, "bag.filter"); addOperator(api::BAG_FOLD, "bag.fold"); + addOperator(api::TABLE_PRODUCT, "table.product"); } if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) { defineType("String", d_solver->getStringSort(), true, true); diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index dd74f0071..420c176f7 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1125,6 +1125,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_MAP: return "bag.map"; case kind::BAG_FILTER: return "bag.filter"; case kind::BAG_FOLD: return "bag.fold"; + case kind::TABLE_PRODUCT: return "table.product"; // fp theory case kind::FLOATINGPOINT_FP: return "fp"; diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp index ed4b501f3..dbd2bbc29 100644 --- a/src/theory/bags/bag_solver.cpp +++ b/src/theory/bags/bag_solver.cpp @@ -78,6 +78,7 @@ void BagSolver::checkBasicOperations() case kind::BAG_DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break; case kind::BAG_FILTER: checkFilter(n); break; case kind::BAG_MAP: checkMap(n); break; + case kind::TABLE_PRODUCT: checkProduct(n); break; default: break; } it++; @@ -303,6 +304,29 @@ void BagSolver::checkFilter(Node n) } } +void BagSolver::checkProduct(Node n) +{ + Assert(n.getKind() == TABLE_PRODUCT); + const set& elementsA = d_state.getElements(n[0]); + const set& elementsB = d_state.getElements(n[1]); + for (const Node& e1 : elementsA) + { + for (const Node& e2 : elementsB) + { + InferInfo i = d_ig.productUp( + n, d_state.getRepresentative(e1), d_state.getRepresentative(e2)); + d_im.lemmaTheoryInference(&i); + } + } + + std::set elements = d_state.getElements(n); + for (const Node& e : elements) + { + InferInfo i = d_ig.productDown(n, d_state.getRepresentative(e)); + d_im.lemmaTheoryInference(&i); + } +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h index fca72b22e..eb578aafd 100644 --- a/src/theory/bags/bag_solver.h +++ b/src/theory/bags/bag_solver.h @@ -98,6 +98,8 @@ class BagSolver : protected EnvObj void checkMap(Node n); /** apply inference rules for filter operator */ void checkFilter(Node n); + /** apply inference rules for product operator */ + void checkProduct(Node n); /** The solver state object */ SolverState& d_state; diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 396e33557..031910cdd 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -92,6 +92,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) case BAG_MAP: response = postRewriteMap(n); break; case BAG_FILTER: response = postRewriteFilter(n); break; case BAG_FOLD: response = postRewriteFold(n); break; + case TABLE_PRODUCT: response = postRewriteProduct(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); break; } } @@ -654,6 +655,20 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const } return BagsRewriteResponse(n, Rewrite::NONE); } + +BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const +{ + Assert(n.getKind() == TABLE_PRODUCT); + TypeNode tableType = n.getType(); + Node empty = d_nm->mkConst(EmptyBag(tableType)); + if (n[0].getKind() == BAG_EMPTY || n[1].getKind() == BAG_EMPTY) + { + return BagsRewriteResponse(empty, Rewrite::PRODUCT_EMPTY); + } + + return BagsRewriteResponse(n, Rewrite::NONE); +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index 3e5b69a1c..f05766c53 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -247,6 +247,15 @@ class BagsRewriter : public TheoryRewriter * where f: T1 -> T2 -> T2 */ BagsRewriteResponse postRewriteFold(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.product A (as bag.empty T2)) = (as bag.empty T) + * - (bag.product (as bag.empty T2)) = (f t ... (f t (f t x))) n times, where n > 0 + * - (bag.fold f t (bag.union_disjoint A B)) = + * (bag.fold f (bag.fold f t A) B) where A < B to break symmetry + * where f: T1 -> T2 -> T2 + */ + BagsRewriteResponse postRewriteProduct(const TNode& n)const; private: /** Reference to the rewriter statistics. */ diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp index 39987ce9d..6514d8d3f 100644 --- a/src/theory/bags/bags_utils.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -14,13 +14,17 @@ */ #include "bags_utils.h" +#include "expr/dtype.h" +#include "expr/dtype_cons.h" #include "expr/emptybag.h" #include "smt/logic_exception.h" +#include "theory/datatypes/tuple_utils.h" #include "theory/sets/normal_form.h" #include "theory/type_enumerator.h" #include "util/rational.h" using namespace cvc5::kind; +using namespace cvc5::theory::datatypes; namespace cvc5 { namespace theory { @@ -136,6 +140,7 @@ Node BagsUtils::evaluate(TNode n) case BAG_MAP: return evaluateBagMap(n); case BAG_FILTER: return evaluateBagFilter(n); case BAG_FOLD: return evaluateBagFold(n); + case TABLE_PRODUCT: return evaluateProduct(n); default: break; } Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n @@ -778,6 +783,52 @@ Node BagsUtils::evaluateBagFold(TNode n) return ret; } +Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2) +{ + Assert(n.getKind() == TABLE_PRODUCT); + Node A = n[0]; + Node B = n[1]; + TypeNode typeA = A.getType().getBagElementType(); + TypeNode typeB = B.getType().getBagElementType(); + Assert(e1.getType().isSubtypeOf(typeA)); + Assert(e2.getType().isSubtypeOf(typeB)); + + TypeNode productTupleType = n.getType().getBagElementType(); + Node tuple = TupleUtils::concatTuples(productTupleType, e1, e2); + return tuple; +} + +Node BagsUtils::evaluateProduct(TNode n) +{ + Assert(n.getKind() == TABLE_PRODUCT); + + // Examples + // -------- + // + // - (table.product (bag (tuple "a") 4) (bag (tuple true) 5)) = + // (bag (tuple "a" true) 20 + + Node A = n[0]; + Node B = n[1]; + + std::map elementsA = BagsUtils::getBagElements(A); + std::map elementsB = BagsUtils::getBagElements(B); + + std::map elements; + + for (const auto& [a, countA] : elementsA) + { + for (const auto& [b, countB] : elementsB) + { + Node element = constructProductTuple(n, a, b); + elements[element] = countA * countB; + } + } + + Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements); + return ret; +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bags_utils.h b/src/theory/bags/bags_utils.h index 61473a023..3b6311ded 100644 --- a/src/theory/bags/bags_utils.h +++ b/src/theory/bags/bags_utils.h @@ -17,8 +17,8 @@ #include "cvc5_private.h" -#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H -#define CVC5__THEORY__BAGS__NORMAL_FORM_H +#ifndef CVC5__THEORY__BAGS__UTILS_H +#define CVC5__THEORY__BAGS__UTILS_H namespace cvc5 { namespace theory { @@ -94,6 +94,21 @@ class BagsUtils */ static Node evaluateBagFilter(TNode n); + /** + * @param n of the form (table.product A B) where A , B of types (Bag T1), + * (Bag T2) respectively. + * @param e1 a tuple of type T1 of the form (tuple a1 ... an) + * @param e2 a tuple of type T2 of the form (tuple b1 ... bn) + * @return (tuple a1 ... an b1 ... bn) + */ + static Node constructProductTuple(TNode n, TNode e1, TNode e2); + + /** + * @param n of the form (table.product A B) where A, B are constants + * @return the evaluation of the cross product of A B + */ + static Node evaluateProduct(TNode n); + private: /** * a high order helper function that return a constant bag that is the result @@ -220,4 +235,4 @@ class BagsUtils } // namespace theory } // namespace cvc5 -#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */ +#endif /* CVC5__THEORY__BAGS__UTILS_H */ diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 3cc2936fb..aa5bf74d8 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -17,15 +17,19 @@ #include "expr/attribute.h" #include "expr/bound_var_manager.h" +#include "expr/dtype_cons.h" #include "expr/emptybag.h" #include "expr/skolem_manager.h" +#include "theory/bags/bags_utils.h" #include "theory/bags/inference_manager.h" #include "theory/bags/solver_state.h" +#include "theory/datatypes/tuple_utils.h" #include "theory/quantifiers/fmf/bounded_integers.h" #include "theory/uf/equality_engine.h" #include "util/rational.h" using namespace cvc5::kind; +using namespace cvc5::theory::datatypes; namespace cvc5 { namespace theory { @@ -563,6 +567,60 @@ InferInfo InferenceGenerator::filterUpwards(Node n, Node e) return inferInfo; } +InferInfo InferenceGenerator::productUp(Node n, Node e1, Node e2) +{ + Assert(n.getKind() == TABLE_PRODUCT); + Node A = n[0]; + Node B = n[1]; + Node tuple = BagsUtils::constructProductTuple(n, e1, e2); + + InferInfo inferInfo(d_im, InferenceId::TABLES_PRODUCT_UP); + + Node countA = getMultiplicityTerm(e1, A); + Node countB = getMultiplicityTerm(e2, B); + + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(tuple, skolem); + + Node multiply = d_nm->mkNode(MULT, countA, countB); + inferInfo.d_conclusion = count.eqNode(multiply); + + return inferInfo; +} + +InferInfo InferenceGenerator::productDown(Node n, Node e) +{ + Assert(n.getKind() == TABLE_PRODUCT); + Assert(e.getType().isSubtypeOf(n.getType().getBagElementType())); + + Node A = n[0]; + Node B = n[1]; + + TypeNode tupleBType = B.getType().getBagElementType(); + TypeNode tupleAType = A.getType().getBagElementType(); + size_t tupleALength = tupleAType.getTupleLength(); + size_t productTupleLength = n.getType().getBagElementType().getTupleLength(); + + std::vector elements = TupleUtils::getTupleElements(e); + Node a = TupleUtils::constructTupleFromElements( + tupleAType, elements, 0, tupleALength - 1); + Node b = TupleUtils::constructTupleFromElements( + tupleBType, elements, tupleALength, productTupleLength - 1); + + InferInfo inferInfo(d_im, InferenceId::TABLES_PRODUCT_DOWN); + + Node countA = getMultiplicityTerm(a, A); + Node countB = getMultiplicityTerm(b, B); + + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); + + Node multiply = d_nm->mkNode(MULT, countA, countB); + inferInfo.d_conclusion = count.eqNode(multiply); + + return inferInfo; +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h index 3d74dbaa2..ed6122356 100644 --- a/src/theory/bags/inference_generator.h +++ b/src/theory/bags/inference_generator.h @@ -271,7 +271,7 @@ class InferenceGenerator * (bag.member e skolem) * (and * (p e) - * (= (bag.count e skolem) (bag.count A))) + * (= (bag.count e skolem) (bag.count e A))) * where skolem is a variable equals (bag.filter p A) */ InferInfo filterDownwards(Node n, Node e); @@ -290,6 +290,29 @@ class InferenceGenerator */ InferInfo filterUpwards(Node n, Node e); + /** + * @param n is a (table.product A B) where A, B are bags of tuples + * @param e1 an element of the form (tuple a1 ... am) + * @param e2 an element of the form (tuple b1 ... bn) + * @return an inference that represents the following + * (= + * (bag.count (tuple a1 ... am b1 ... bn) skolem) + * (* (bag.count e1 A) (bag.count e2 B))) + * where skolem is a variable equals (bag.product A B) + */ + InferInfo productUp(Node n, Node e1, Node e2); + + /** + * @param n is a (table.product A B) where A, B are bags of tuples + * @param e an element of the form (tuple a1 ... am b1 ... bn) + * @return an inference that represents the following + * (= + * (bag.count (tuple a1 ... am b1 ... bn) skolem) + * (* (bag.count (tuple a1 ... am A) (bag.count (tuple b1 ... bn) B))) + * where skolem is a variable equals (bag.product A B) + */ + InferInfo productDown(Node n, Node e); + /** * @param element of type T * @param bag of type (bag T) diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 7d995dd7b..345b71e9b 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -113,4 +113,10 @@ typerule BAG_FOLD ::cvc5::theory::bags::BagFoldTypeRule construle BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule construle BAG_MAKE ::cvc5::theory::bags::BagMakeTypeRule + +# bag.product operator returns the cross product of two tables +operator TABLE_PRODUCT 2 "table cross product" + +typerule TABLE_PRODUCT ::cvc5::theory::bags::TableProductTypeRule + endtheory diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index 9bd0c3a86..576f1245c 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -56,6 +56,7 @@ const char* toString(Rewrite r) case Rewrite::MAP_BAG_MAKE: return "MAP_BAG_MAKE"; case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT"; case Rewrite::MEMBER: return "MEMBER"; + case Rewrite::PRODUCT_EMPTY: return "PRODUCT_EMPTY"; case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION"; case Rewrite::REMOVE_MIN: return "REMOVE_MIN"; case Rewrite::REMOVE_RETURN_LEFT: return "REMOVE_RETURN_LEFT"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index e1ef38c4b..e7f2113f9 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -60,6 +60,7 @@ enum class Rewrite : uint32_t MAP_BAG_MAKE, MAP_UNION_DISJOINT, MEMBER, + PRODUCT_EMPTY, REMOVE_FROM_UNION, REMOVE_MIN, REMOVE_RETURN_LEFT, diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index 689b0e208..b0c79fb1d 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -449,6 +449,51 @@ TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager, return retType; } +TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::TABLE_PRODUCT); + Node A = n[0]; + Node B = n[1]; + TypeNode typeA = n[0].getType(check); + TypeNode typeB = n[1].getType(check); + + if (check && !(typeA.isBag() && typeB.isBag())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects two bags. " + << "Found two terms of types '" << typeA << "' and '" << typeB + << "' respectively."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode elementAType = typeA.getBagElementType(); + TypeNode elementBType = typeB.getBagElementType(); + + if (check && !(elementAType.isTuple() && elementBType.isTuple())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects two tables (bags of tuples). " + << "Found two terms of types '" << typeA << "' and '" << typeB + << "' respectively."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + std::vector productTupleTypes; + std::vector tupleATypes = elementAType.getTupleTypes(); + std::vector tupleBTypes = elementBType.getTupleTypes(); + + productTupleTypes.insert( + productTupleTypes.end(), tupleATypes.begin(), tupleATypes.end()); + productTupleTypes.insert( + productTupleTypes.end(), tupleBTypes.begin(), tupleBTypes.end()); + + TypeNode retTupleType = nodeManager->mkTupleType(productTupleTypes); + TypeNode retType = nodeManager->mkBagType(retTupleType); + return retType; +} + Cardinality BagsProperties::computeCardinality(TypeNode type) { return Cardinality::INTEGERS; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 76c179a62..8673f7296 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -159,6 +159,15 @@ struct BagFoldTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct BagFoldTypeRule */ +/** + * Type rule for (table.product A B) to make sure A,B are bags of tuples, + * and get the type of the cross product + */ +struct TableProductTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct BagFoldTypeRule */ + struct BagsProperties { static Cardinality computeCardinality(TypeNode type); diff --git a/src/theory/datatypes/tuple_utils.cpp b/src/theory/datatypes/tuple_utils.cpp new file mode 100644 index 000000000..d691b3831 --- /dev/null +++ b/src/theory/datatypes/tuple_utils.cpp @@ -0,0 +1,123 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds, Mudathir Mohamed + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utility functions for data types. + */ + +#include "tuple_utils.h" + +#include "expr/dtype.h" +#include "expr/dtype_cons.h" + +using namespace cvc5::kind; + +namespace cvc5 { +namespace theory { +namespace datatypes { + +Node TupleUtils::nthElementOfTuple(Node tuple, int n_th) +{ + if (tuple.getKind() == APPLY_CONSTRUCTOR) + { + return tuple[n_th]; + } + TypeNode tn = tuple.getType(); + const DType& dt = tn.getDType(); + return NodeManager::currentNM()->mkNode( + APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, n_th), tuple); +} + +std::vector TupleUtils::getTupleElements(Node tuple) +{ + Assert(tuple.getType().isTuple()); + size_t tupleLength = tuple.getType().getTupleLength(); + std::vector elements; + for (size_t i = 0; i < tupleLength; i++) + { + elements.push_back(TupleUtils::nthElementOfTuple(tuple, i)); + } + return elements; +} + +std::vector TupleUtils::getTupleElements(Node tuple1, Node tuple2) +{ + std::vector elements; + std::vector elementsA = getTupleElements(tuple1); + size_t tuple1Length = tuple1.getType().getTupleLength(); + for (size_t i = 0; i < tuple1Length; i++) + { + elements.push_back(TupleUtils::nthElementOfTuple(tuple1, i)); + } + + size_t tuple2Length = tuple2.getType().getTupleLength(); + for (size_t i = 0; i < tuple2Length; i++) + { + elements.push_back(TupleUtils::nthElementOfTuple(tuple2, i)); + } + return elements; +} + +Node TupleUtils::constructTupleFromElements(TypeNode tupleType, + const std::vector& elements, + size_t start, + size_t end) +{ + std::vector tupleElements; + // add the constructor first + Node constructor = tupleType.getDType()[0].getConstructor(); + tupleElements.push_back(constructor); + // add the elements of the tuple + for (size_t i = start; i <= end; i++) + { + tupleElements.push_back(elements[i]); + } + NodeManager* nm = NodeManager::currentNM(); + Node tuple = nm->mkNode(APPLY_CONSTRUCTOR, tupleElements); + return tuple; +} + +Node TupleUtils::concatTuples(TypeNode tupleType, Node tuple1, Node tuple2) +{ + std::vector tupleElements; + // add the constructor first + Node constructor = tupleType.getDType()[0].getConstructor(); + tupleElements.push_back(constructor); + + // add the flattened concatenation of the two tuples e1, e2 + std::vector elements = getTupleElements(tuple1, tuple2); + tupleElements.insert(tupleElements.end(), elements.begin(), elements.end()); + + // construct the returned tuple + NodeManager* nm = NodeManager::currentNM(); + Node tuple = nm->mkNode(APPLY_CONSTRUCTOR, tupleElements); + return tuple; +} + +Node TupleUtils::reverseTuple(Node tuple) +{ + Assert(tuple.getType().isTuple()); + std::vector elements; + std::vector tuple_types = tuple.getType().getTupleTypes(); + std::reverse(tuple_types.begin(), tuple_types.end()); + TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types); + const DType& dt = tn.getDType(); + elements.push_back(dt[0].getConstructor()); + for (int i = tuple_types.size() - 1; i >= 0; --i) + { + elements.push_back(nthElementOfTuple(tuple, i)); + } + return NodeManager::currentNM()->mkNode(APPLY_CONSTRUCTOR, elements); +} + +} // namespace datatypes +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/datatypes/tuple_utils.h b/src/theory/datatypes/tuple_utils.h new file mode 100644 index 000000000..595052c72 --- /dev/null +++ b/src/theory/datatypes/tuple_utils.h @@ -0,0 +1,83 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds, Mudathir Mohamed + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utility functions for data types. + */ + +#ifndef CVC5__THEORY__TUPLE__UTILS_H +#define CVC5__THEORY__TUPLE__UTILS_H + +#include "expr/node.h" + +namespace cvc5 { +namespace theory { +namespace datatypes { + +class TupleUtils +{ + public: + /** + * @param tuple a node of tuple type + * @param n_th the index of the element to be extracted, and must satisfy the + * constraint 0 <= n_th < length of tuple. + * @return tuple element at index n_th + */ + static Node nthElementOfTuple(Node tuple, int n_th); + + /** + * @param tuple a tuple node of the form (tuple a_1 ... a_n) + * @return the vector [a_1, ... a_n] + */ + static std::vector getTupleElements(Node tuple); + + /** + * @param tuple1 a tuple node of the form (tuple a_1 ... a_n) + * @param tuple2 a tuple node of the form (tuple b_1 ... b_n) + * @return the vector [a_1, ... a_n, b_1, ... b_n] + */ + static std::vector getTupleElements(Node tuple1, Node tuple2); + + /** + * construct a tuple from a list of elements + * @param tupleType the type of the returned tuple + * @param elements the list of nodes + * @param start the index of the first element + * @param end the index of the last element + * @pre the elements from start to end should match the tuple type + * @return a tuple of constructed from elements from start to end + */ + static Node constructTupleFromElements(TypeNode tupleType, + const std::vector& elements, + size_t start, + size_t end); + + /** + * construct a flattened tuple from two tuples + * @param tupleType the type of the returned tuple + * @param tuple1 a tuple node of the form (tuple a_1 ... a_n) + * @param tuple2 a tuple node of the form (tuple b_1 ... b_n) + * @pre the elements of tuple1, tuple2 should match the tuple type + * @return (tuple a1 ... an b1 ... bn) + */ + static Node concatTuples(TypeNode tupleType, Node tuple1, Node tuple2); + + /** + * @param tuple a tuple node of the form (tuple e_1 ... e_n) + * @return the reverse of the argument, i.e., (tuple e_n ... e_1) + */ + static Node reverseTuple(Node tuple); +}; +} // namespace datatypes +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__THEORY__TUPLE__UTILS_H */ diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index 240d6e293..791819f38 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -124,6 +124,8 @@ const char* toString(InferenceId i) case InferenceId::BAGS_FILTER_UP: return "BAGS_FILTER_UP"; case InferenceId::BAGS_FOLD: return "BAGS_FOLD"; case InferenceId::BAGS_CARD: return "BAGS_CARD"; + case InferenceId::TABLES_PRODUCT_UP: return "TABLES_PRODUCT_UP"; + case InferenceId::TABLES_PRODUCT_DOWN: return "TABLES_PRODUCT_DOWN"; case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT"; case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA: diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index 2fb3ae003..4301e0d16 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -186,6 +186,8 @@ enum class InferenceId BAGS_FILTER_UP, BAGS_FOLD, BAGS_CARD, + TABLES_PRODUCT_UP, + TABLES_PRODUCT_DOWN, // ---------------------------------- end bags theory // ---------------------------------- bitvector theory diff --git a/src/theory/sets/rels_utils.cpp b/src/theory/sets/rels_utils.cpp new file mode 100644 index 000000000..fdd9e6356 --- /dev/null +++ b/src/theory/sets/rels_utils.cpp @@ -0,0 +1,81 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utility functions for relations. + */ + +#include "rels_utils.h" + +#include "expr/dtype.h" +#include "expr/dtype_cons.h" +#include "theory/datatypes/tuple_utils.h" + +using namespace cvc5::theory::datatypes; + +namespace cvc5 { +namespace theory { +namespace sets { + +std::set RelsUtils::computeTC(const std::set& members, Node rel) +{ + std::set::iterator mem_it = members.begin(); + std::map ele_num_map; + std::set tc_rel_mem; + + while (mem_it != members.end()) + { + Node fst = TupleUtils::nthElementOfTuple(*mem_it, 0); + Node snd = TupleUtils::nthElementOfTuple(*mem_it, 1); + std::set traversed; + traversed.insert(fst); + computeTC(rel, members, fst, snd, traversed, tc_rel_mem); + mem_it++; + } + return tc_rel_mem; +} + +void RelsUtils::computeTC(Node rel, + const std::set& members, + Node a, + Node b, + std::set& traversed, + std::set& transitiveClosureMembers) +{ + transitiveClosureMembers.insert(constructPair(rel, a, b)); + if (traversed.find(b) != traversed.end()) + { + return; + } + traversed.insert(b); + std::set::iterator mem_it = members.begin(); + while (mem_it != members.end()) + { + Node new_fst = TupleUtils::nthElementOfTuple(*mem_it, 0); + Node new_snd = TupleUtils::nthElementOfTuple(*mem_it, 1); + if (b == new_fst) + { + computeTC(rel, members, a, new_snd, traversed, transitiveClosureMembers); + } + mem_it++; + } +} + +Node RelsUtils::constructPair(Node rel, Node a, Node b) +{ + const DType& dt = rel.getType().getSetElementType().getDType(); + return NodeManager::currentNM()->mkNode( + kind::APPLY_CONSTRUCTOR, dt[0].getConstructor(), a, b); +} + +} // namespace sets +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/sets/rels_utils.h b/src/theory/sets/rels_utils.h index 46eeecd58..c070ad1da 100644 --- a/src/theory/sets/rels_utils.h +++ b/src/theory/sets/rels_utils.h @@ -10,90 +10,60 @@ * directory for licensing information. * **************************************************************************** * - * Extension to Sets theory. + * Utility functions for relations. */ #ifndef SRC_THEORY_SETS_RELS_UTILS_H_ #define SRC_THEORY_SETS_RELS_UTILS_H_ -#include "expr/dtype.h" -#include "expr/dtype_cons.h" #include "expr/node.h" namespace cvc5 { namespace theory { namespace sets { -class RelsUtils { +class RelsUtils +{ + public: + /** + * compute the transitive closure of a binary relation + * @param members constant nodes of type (Tuple E E) that are known to in the + * relation rel + * @param rel a binary relation of type (Set (Tuple E E)) + * @pre all members need to be constants + * @return the transitive closure of the relation + */ + static std::set computeTC(const std::set& members, Node rel); -public: + /** + * add all pairs (a, c) to the transitive closures where c is reachable from b + * in the transitive relation in a depth first search manner. + * @param rel a binary relation of type (Set (Tuple E E)) + * @param members constant nodes of type (Tuple E E) that are known to be in + * the relation rel + * @param a a node of type E where (a,b) is an element in the transitive + * closure + * @param b a node of type E where (a,b) is an element in the transitive + * closure + * @param traversed the set of members that have been visited so far + * @param transitiveClosureMembers members of the transitive closure computed + * so far + */ + static void computeTC(Node rel, + const std::set& members, + Node a, + Node b, + std::set& traversed, + std::set& transitiveClosureMembers); - // Assumption: the input rel_mem contains all constant pairs - static std::set< Node > computeTC( std::set< Node > rel_mem, Node rel ) { - std::set< Node >::iterator mem_it = rel_mem.begin(); - std::map< Node, int > ele_num_map; - std::set< Node > tc_rel_mem; - - while( mem_it != rel_mem.end() ) { - Node fst = nthElementOfTuple( *mem_it, 0 ); - Node snd = nthElementOfTuple( *mem_it, 1 ); - std::set< Node > traversed; - traversed.insert(fst); - computeTC(rel, rel_mem, fst, snd, traversed, tc_rel_mem); - mem_it++; - } - return tc_rel_mem; - } - - static void computeTC( Node rel, std::set< Node >& rel_mem, Node fst, - Node snd, std::set< Node >& traversed, std::set< Node >& tc_rel_mem ) { - tc_rel_mem.insert(constructPair(rel, fst, snd)); - if( traversed.find(snd) == traversed.end() ) { - traversed.insert(snd); - } else { - return; - } - - std::set< Node >::iterator mem_it = rel_mem.begin(); - while( mem_it != rel_mem.end() ) { - Node new_fst = nthElementOfTuple( *mem_it, 0 ); - Node new_snd = nthElementOfTuple( *mem_it, 1 ); - if( snd == new_fst ) { - computeTC(rel, rel_mem, fst, new_snd, traversed, tc_rel_mem); - } - mem_it++; - } - } - - static Node nthElementOfTuple( Node tuple, int n_th ) { - if( tuple.getKind() == kind::APPLY_CONSTRUCTOR ) { - return tuple[n_th]; - } - TypeNode tn = tuple.getType(); - const DType& dt = tn.getDType(); - return NodeManager::currentNM()->mkNode( - kind::APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, n_th), tuple); - } - - static Node reverseTuple( Node tuple ) { - Assert(tuple.getType().isTuple()); - std::vector elements; - std::vector tuple_types = tuple.getType().getTupleTypes(); - std::reverse( tuple_types.begin(), tuple_types.end() ); - TypeNode tn = NodeManager::currentNM()->mkTupleType( tuple_types ); - const DType& dt = tn.getDType(); - elements.push_back(dt[0].getConstructor()); - for(int i = tuple_types.size() - 1; i >= 0; --i) { - elements.push_back( nthElementOfTuple(tuple, i) ); - } - return NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, elements ); - } - static Node constructPair(Node rel, Node a, Node b) { - const DType& dt = rel.getType().getSetElementType().getDType(); - return NodeManager::currentNM()->mkNode( - kind::APPLY_CONSTRUCTOR, dt[0].getConstructor(), a, b); - } - + /** + * construct a pair from two elements + * @param rel a node of type (Set (Tuple E E)) + * @param a a node of type E + * @param b a node of type E + * @return a tuple (tuple a b) + */ + static Node constructPair(Node rel, Node a, Node b); }; } // namespace sets } // namespace theory diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index 49f8f053a..d6a52b76e 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -15,13 +15,17 @@ #include "theory/sets/theory_sets_rels.h" +#include "expr/dtype.h" +#include "expr/dtype_cons.h" #include "expr/skolem_manager.h" +#include "theory/datatypes/tuple_utils.h" #include "theory/sets/theory_sets.h" #include "theory/sets/theory_sets_private.h" #include "util/rational.h" using namespace std; using namespace cvc5::kind; +using namespace cvc5::theory::datatypes; namespace cvc5 { namespace theory { @@ -268,7 +272,7 @@ void TheorySetsRels::check(Theory::Effort level) std::vector tupleTypes = erType.getTupleTypes(); for (unsigned i = 0, tlen = erType.getTupleLength(); i < tlen; i++) { - Node element = RelsUtils::nthElementOfTuple(eqc_node, i); + Node element = TupleUtils::nthElementOfTuple(eqc_node, i); if (!element.isConst()) { makeSharedTerm(element, tupleTypes[i]); @@ -306,7 +310,7 @@ void TheorySetsRels::check(Theory::Effort level) unsigned int min_card = join_image_term[1].getConst().getNumerator().getUnsignedInt(); while( mem_rep_it != (*rel_mem_it).second.end() ) { - Node fst_mem_rep = RelsUtils::nthElementOfTuple( *mem_rep_it, 0 ); + Node fst_mem_rep = TupleUtils::nthElementOfTuple( *mem_rep_it, 0 ); if( hasChecked.find( fst_mem_rep ) != hasChecked.end() ) { ++mem_rep_it; @@ -333,12 +337,14 @@ void TheorySetsRels::check(Theory::Effort level) std::vector< Node >::iterator mem_rep_exp_it_snd = (*rel_mem_exp_it).second.begin(); while( mem_rep_exp_it_snd != (*rel_mem_exp_it).second.end() ) { - Node fst_element_snd_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 0 ); + Node fst_element_snd_mem = + TupleUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 0 ); if( areEqual( fst_mem_rep, fst_element_snd_mem ) ) { bool notExist = true; std::vector< Node >::iterator existing_mem_it = existing_members.begin(); - Node snd_element_snd_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 1 ); + Node snd_element_snd_mem = + TupleUtils::nthElementOfTuple( (*mem_rep_exp_it_snd)[0], 1 ); while( existing_mem_it != existing_members.end() ) { if( areEqual( (*existing_mem_it), snd_element_snd_mem ) ) { @@ -410,7 +416,7 @@ void TheorySetsRels::check(Theory::Effort level) Node reason = exp; Node conclusion = d_trueNode; std::vector< Node > distinct_skolems; - Node fst_mem_element = RelsUtils::nthElementOfTuple( exp[0], 0 ); + Node fst_mem_element = TupleUtils::nthElementOfTuple( exp[0], 0 ); if( exp[1] != join_image_term ) { reason = @@ -451,8 +457,8 @@ void TheorySetsRels::check(Theory::Effort level) d_rel_nodes.insert( iden_term ); } Node reason = exp; - Node fst_mem = RelsUtils::nthElementOfTuple( exp[0], 0 ); - Node snd_mem = RelsUtils::nthElementOfTuple( exp[0], 1 ); + Node fst_mem = TupleUtils::nthElementOfTuple( exp[0], 0 ); + Node snd_mem = TupleUtils::nthElementOfTuple( exp[0], 1 ); const DType& dt = iden_term[0].getType().getSetElementType().getDType(); Node fact = nm->mkNode( SET_MEMBER, @@ -489,7 +495,7 @@ void TheorySetsRels::check(Theory::Effort level) while( mem_rep_exp_it != (*rel_mem_exp_it).second.end() ) { Node reason = *mem_rep_exp_it; - Node fst_exp_mem = RelsUtils::nthElementOfTuple( (*mem_rep_exp_it)[0], 0 ); + Node fst_exp_mem = TupleUtils::nthElementOfTuple( (*mem_rep_exp_it)[0], 0 ); Node new_mem = RelsUtils::constructPair( iden_term, fst_exp_mem, fst_exp_mem ); if( (*mem_rep_exp_it)[1] != iden_term_rel ) { @@ -548,8 +554,8 @@ void TheorySetsRels::check(Theory::Effort level) // add mem_rep to d_tcrRep_tcGraph TC_IT tc_it = d_tcr_tcGraph.find( tc_rel ); - Node mem_rep_fst = getRepresentative( RelsUtils::nthElementOfTuple( mem_rep, 0 ) ); - Node mem_rep_snd = getRepresentative( RelsUtils::nthElementOfTuple( mem_rep, 1 ) ); + Node mem_rep_fst = getRepresentative(TupleUtils::nthElementOfTuple( mem_rep, 0 ) ); + Node mem_rep_snd = getRepresentative(TupleUtils::nthElementOfTuple( mem_rep, 1 ) ); Node mem_rep_tup = RelsUtils::constructPair( tc_rel, mem_rep_fst, mem_rep_snd ); if( tc_it != d_tcr_tcGraph.end() ) { @@ -580,8 +586,8 @@ void TheorySetsRels::check(Theory::Effort level) exp_map[mem_rep_tup] = exp; d_tcr_tcGraph_exps[tc_rel] = exp_map; } - Node fst_element = RelsUtils::nthElementOfTuple( exp[0], 0 ); - Node snd_element = RelsUtils::nthElementOfTuple( exp[0], 1 ); + Node fst_element = TupleUtils::nthElementOfTuple( exp[0], 0 ); + Node snd_element = TupleUtils::nthElementOfTuple( exp[0], 1 ); Node sk_1 = d_skCache.mkTypedSkolemCached(fst_element.getType(), exp[0], tc_rel[0], @@ -631,8 +637,8 @@ void TheorySetsRels::check(Theory::Effort level) if( tc_it != d_rRep_tcGraph.end() ) { bool isReachable = false; std::unordered_set seen; - isTCReachable( getRepresentative( RelsUtils::nthElementOfTuple(mem_rep, 0) ), - getRepresentative( RelsUtils::nthElementOfTuple(mem_rep, 1) ), seen, tc_it->second, isReachable ); + isTCReachable( getRepresentative(TupleUtils::nthElementOfTuple(mem_rep, 0) ), + getRepresentative(TupleUtils::nthElementOfTuple(mem_rep, 1) ), seen, tc_it->second, isReachable ); return isReachable; } return false; @@ -680,8 +686,8 @@ void TheorySetsRels::check(Theory::Effort level) for (size_t i = 0, msize = members.size(); i < msize; i++) { - Node fst_element_rep = getRepresentative( RelsUtils::nthElementOfTuple( members[i], 0 )); - Node snd_element_rep = getRepresentative( RelsUtils::nthElementOfTuple( members[i], 1 )); + Node fst_element_rep = getRepresentative(TupleUtils::nthElementOfTuple( members[i], 0 )); + Node snd_element_rep = getRepresentative(TupleUtils::nthElementOfTuple( members[i], 1 )); Node tuple_rep = RelsUtils::constructPair( rel_rep, fst_element_rep, snd_element_rep ); std::map >::iterator rel_tc_graph_it = rel_tc_graph.find(fst_element_rep); @@ -743,12 +749,15 @@ void TheorySetsRels::check(Theory::Effort level) std::unordered_set& seen) { NodeManager* nm = NodeManager::currentNM(); - Node tc_mem = RelsUtils::constructPair( tc_rel, RelsUtils::nthElementOfTuple((reasons.front())[0], 0), RelsUtils::nthElementOfTuple((reasons.back())[0], 1) ); + Node tc_mem = RelsUtils::constructPair( tc_rel, + TupleUtils::nthElementOfTuple((reasons.front())[0], 0), + TupleUtils::nthElementOfTuple((reasons.back())[0], 1) ); std::vector< Node > all_reasons( reasons ); for( unsigned int i = 0 ; i < reasons.size()-1; i++ ) { - Node fst_element_end = RelsUtils::nthElementOfTuple( reasons[i][0], 1 ); - Node snd_element_begin = RelsUtils::nthElementOfTuple( reasons[i+1][0], 0 ); + Node fst_element_end = TupleUtils::nthElementOfTuple( reasons[i][0], 1 ); + Node snd_element_begin = + TupleUtils::nthElementOfTuple( reasons[i+1][0], 0 ); if( fst_element_end != snd_element_begin ) { all_reasons.push_back( NodeManager::currentNM()->mkNode(kind::EQUAL, fst_element_end, snd_element_begin) ); } @@ -823,12 +832,12 @@ void TheorySetsRels::check(Theory::Effort level) unsigned int i = 0; for(; i < s1_len; ++i) { - r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); + r1_element.push_back(TupleUtils::nthElementOfTuple(mem, i)); } const DType& dt2 = pt_rel[1].getType().getSetElementType().getDType(); r2_element.push_back(dt2[0].getConstructor()); for(; i < tup_len; ++i) { - r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); + r2_element.push_back(TupleUtils::nthElementOfTuple(mem, i)); } Node reason = exp; Node mem1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_element); @@ -885,14 +894,14 @@ void TheorySetsRels::check(Theory::Effort level) unsigned int i = 0; r1_element.push_back(dt1[0].getConstructor()); for(; i < s1_len-1; ++i) { - r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); + r1_element.push_back(TupleUtils::nthElementOfTuple(mem, i)); } r1_element.push_back(shared_x); const DType& dt2 = join_rel[1].getType().getSetElementType().getDType(); r2_element.push_back(dt2[0].getConstructor()); r2_element.push_back(shared_x); for(; i < tup_len; ++i) { - r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); + r2_element.push_back(TupleUtils::nthElementOfTuple(mem, i)); } Node mem1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_element); Node mem2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_element); @@ -966,7 +975,7 @@ void TheorySetsRels::check(Theory::Effort level) } Node reason = exp; - Node reversed_mem = RelsUtils::reverseTuple( exp[0] ); + Node reversed_mem = TupleUtils::reverseTuple( exp[0] ); if( tp_rel != exp[1] ) { reason = NodeManager::currentNM()->mkNode(kind::AND, reason, NodeManager::currentNM()->mkNode(kind::EQUAL, tp_rel, exp[1])); @@ -1063,7 +1072,7 @@ void TheorySetsRels::check(Theory::Effort level) kind::AND, reason, nm->mkNode(kind::EQUAL, rel[0], exps[i][1])); } sendInfer( - nm->mkNode(SET_MEMBER, RelsUtils::reverseTuple(exps[i][0]), rel), + nm->mkNode(SET_MEMBER, TupleUtils::reverseTuple(exps[i][0]), rel), InferenceId::SETS_RELS_TRANSPOSE_REV, reason); } @@ -1108,9 +1117,8 @@ void TheorySetsRels::check(Theory::Effort level) std::vector reasons; if (rk == kind::RELATION_JOIN) { - Node r1_rmost = - RelsUtils::nthElementOfTuple(r1_rep_exps[i][0], r1_tuple_len - 1); - Node r2_lmost = RelsUtils::nthElementOfTuple(r2_rep_exps[j][0], 0); + Node r1_rmost = TupleUtils::nthElementOfTuple(r1_rep_exps[i][0], r1_tuple_len - 1); + Node r2_lmost = TupleUtils::nthElementOfTuple(r2_rep_exps[j][0], 0); // Since we require notification r1_rmost and r2_lmost are equal, // they must be shared terms of theory of sets. Hence, we make the // following calls to makeSharedTerm to ensure this is the case. @@ -1140,14 +1148,18 @@ void TheorySetsRels::check(Theory::Effort level) unsigned int l = 1; for( ; k < r1_tuple_len - 1; ++k ) { - tuple_elements.push_back( RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) ); + tuple_elements.push_back( + TupleUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) ); } if(isProduct) { - tuple_elements.push_back( RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) ); - tuple_elements.push_back( RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], 0 ) ); + tuple_elements.push_back( + TupleUtils::nthElementOfTuple( r1_rep_exps[i][0], k ) ); + tuple_elements.push_back( + TupleUtils::nthElementOfTuple( r2_rep_exps[j][0], 0 ) ); } for( ; l < r2_tuple_len; ++l ) { - tuple_elements.push_back( RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], l ) ); + tuple_elements.push_back( + TupleUtils::nthElementOfTuple( r2_rep_exps[j][0], l ) ); } Node composed_tuple = @@ -1216,8 +1228,8 @@ void TheorySetsRels::check(Theory::Effort level) size_t tlen = atn.getTupleLength(); for (size_t i = 0; i < tlen; i++) { - if (!areEqual(RelsUtils::nthElementOfTuple(a, i), - RelsUtils::nthElementOfTuple(b, i))) + if (!areEqual(TupleUtils::nthElementOfTuple(a, i), + TupleUtils::nthElementOfTuple(b, i))) { return false; } @@ -1278,7 +1290,7 @@ void TheorySetsRels::check(Theory::Effort level) void TheorySetsRels::computeTupleReps( Node n ) { if( d_tuple_reps.find( n ) == d_tuple_reps.end() ){ for( unsigned i = 0; i < n.getType().getTupleLength(); i++ ){ - d_tuple_reps[n].push_back( getRepresentative( RelsUtils::nthElementOfTuple(n, i) ) ); + d_tuple_reps[n].push_back( getRepresentative(TupleUtils::nthElementOfTuple(n, i) ) ); } } } @@ -1295,7 +1307,7 @@ void TheorySetsRels::check(Theory::Effort level) std::vector tupleTypes = n[0].getType().getTupleTypes(); for (unsigned int i = 0; i < n[0].getType().getTupleLength(); i++) { - Node element = RelsUtils::nthElementOfTuple(n[0], i); + Node element = TupleUtils::nthElementOfTuple(n[0], i); makeSharedTerm(element, tupleTypes[i]); tuple_elements.push_back(element); } diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index cc642127c..6f6b9c38e 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -16,13 +16,16 @@ #include "theory/sets/theory_sets_rewriter.h" #include "expr/attribute.h" +#include "expr/dtype.h" #include "expr/dtype_cons.h" #include "options/sets_options.h" +#include "theory/datatypes/tuple_utils.h" #include "theory/sets/normal_form.h" #include "theory/sets/rels_utils.h" #include "util/rational.h" using namespace cvc5::kind; +using namespace cvc5::theory::datatypes; namespace cvc5 { namespace theory { @@ -350,7 +353,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { std::set::iterator tuple_it = tuple_set.begin(); while(tuple_it != tuple_set.end()) { - new_tuple_set.insert(RelsUtils::reverseTuple(*tuple_it)); + new_tuple_set.insert(TupleUtils::reverseTuple(*tuple_it)); ++tuple_it; } Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType()); @@ -389,7 +392,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { std::vector left_tuple; left_tuple.push_back(tn.getDType()[0].getConstructor()); for(int i = 0; i < left_len; i++) { - left_tuple.push_back(RelsUtils::nthElementOfTuple(*left_it,i)); + left_tuple.push_back(TupleUtils::nthElementOfTuple(*left_it,i)); } std::set::iterator right_it = right.begin(); int right_len = (*right_it).getType().getTupleLength(); @@ -397,7 +400,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { Trace("rels-debug") << "Sets::postRewrite processing right_it = " << *right_it << std::endl; std::vector right_tuple; for(int j = 0; j < right_len; j++) { - right_tuple.push_back(RelsUtils::nthElementOfTuple(*right_it,j)); + right_tuple.push_back(TupleUtils::nthElementOfTuple(*right_it,j)); } std::vector new_tuple; new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end()); @@ -437,15 +440,16 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { std::vector left_tuple; left_tuple.push_back(tn.getDType()[0].getConstructor()); for(int i = 0; i < left_len - 1; i++) { - left_tuple.push_back(RelsUtils::nthElementOfTuple(*left_it,i)); + left_tuple.push_back(TupleUtils::nthElementOfTuple(*left_it,i)); } std::set::iterator right_it = right.begin(); int right_len = (*right_it).getType().getTupleLength(); while(right_it != right.end()) { - if(RelsUtils::nthElementOfTuple(*left_it,left_len-1) == RelsUtils::nthElementOfTuple(*right_it,0)) { + if(TupleUtils::nthElementOfTuple(*left_it,left_len-1) == TupleUtils::nthElementOfTuple(*right_it,0)) { std::vector right_tuple; for(int j = 1; j < right_len; j++) { - right_tuple.push_back(RelsUtils::nthElementOfTuple(*right_it,j)); + right_tuple.push_back( + TupleUtils::nthElementOfTuple(*right_it,j)); } std::vector new_tuple; new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end()); @@ -508,7 +512,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { std::set::iterator rel_mems_it = rel_mems.begin(); while( rel_mems_it != rel_mems.end() ) { - Node fst_mem = RelsUtils::nthElementOfTuple( *rel_mems_it, 0); + Node fst_mem = TupleUtils::nthElementOfTuple( *rel_mems_it, 0); iden_rel_mems.insert(RelsUtils::constructPair(node, fst_mem, fst_mem)); ++rel_mems_it; } @@ -548,7 +552,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { std::set::iterator rel_mems_it = rel_mems.begin(); while( rel_mems_it != rel_mems.end() ) { - Node fst_mem = RelsUtils::nthElementOfTuple( *rel_mems_it, 0); + Node fst_mem = TupleUtils::nthElementOfTuple( *rel_mems_it, 0); if( has_checked.find( fst_mem ) != has_checked.end() ) { ++rel_mems_it; continue; @@ -557,9 +561,10 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { std::set existing_mems; std::set::iterator rel_mems_it_snd = rel_mems.begin(); while( rel_mems_it_snd != rel_mems.end() ) { - Node fst_mem_snd = RelsUtils::nthElementOfTuple( *rel_mems_it_snd, 0); + Node fst_mem_snd = TupleUtils::nthElementOfTuple( *rel_mems_it_snd, 0); if( fst_mem == fst_mem_snd ) { - existing_mems.insert( RelsUtils::nthElementOfTuple( *rel_mems_it_snd, 1) ); + existing_mems.insert( + TupleUtils::nthElementOfTuple( *rel_mems_it_snd, 1) ); } ++rel_mems_it_snd; } diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index c49d5004c..67603be82 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1678,6 +1678,9 @@ set(regress_1_tests regress1/bags/murxla1.smt2 regress1/bags/murxla2.smt2 regress1/bags/murxla3.smt2 + regress1/bags/product1.smt2 + regress1/bags/product2.smt2 + regress1/bags/product3.smt2 regress1/bags/subbag1.smt2 regress1/bags/subbag2.smt2 regress1/bags/union_disjoint.smt2 diff --git a/test/regress/regress1/bags/product1.smt2 b/test/regress/regress1/bags/product1.smt2 new file mode 100644 index 000000000..2f7f09058 --- /dev/null +++ b/test/regress/regress1/bags/product1.smt2 @@ -0,0 +1,11 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag (Tuple String))) +(declare-fun B () (Bag (Tuple Bool))) +(declare-fun C () (Bag (Tuple String Bool))) +(declare-fun x () (Tuple String)) +(declare-fun y () (Tuple Bool)) +(assert (= (bag.count x A) 5)) +(assert (= (bag.count y B) 4)) +(assert (= C (table.product A B))) +(check-sat) diff --git a/test/regress/regress1/bags/product2.smt2 b/test/regress/regress1/bags/product2.smt2 new file mode 100644 index 000000000..ee7d1712f --- /dev/null +++ b/test/regress/regress1/bags/product2.smt2 @@ -0,0 +1,14 @@ +(set-logic ALL) +(set-info :status unsat) +(declare-fun A () (Bag (Tuple Int Int Int))) +(declare-fun B () (Bag (Tuple Int Int Int))) +(declare-fun x () (Tuple Int Int Int)) +(assert (= x (tuple 1 2 3))) +(declare-fun y () (Tuple Int Int Int)) +(assert (= y (tuple 3 2 1))) +(declare-fun z () (Tuple Int Int Int Int Int Int)) +(assert (= z (tuple 1 2 3 3 2 1))) +(assert (bag.member x A)) +(assert (bag.member y B)) +(assert (not (bag.member z (table.product A B)))) +(check-sat) diff --git a/test/regress/regress1/bags/product3.smt2 b/test/regress/regress1/bags/product3.smt2 new file mode 100644 index 000000000..8f2e8c38f --- /dev/null +++ b/test/regress/regress1/bags/product3.smt2 @@ -0,0 +1,21 @@ +(set-logic ALL) + +(set-info :status sat) + +(declare-fun A () (Bag (Tuple Int Int Int))) +(declare-fun B () (Bag (Tuple Int Int Int))) +(declare-fun C () (Bag (Tuple Int Int Int Int Int Int))) + +(assert (= C (table.product A B))) + +(declare-fun x () (Tuple Int Int Int)) +(declare-fun y () (Tuple Int Int Int)) +(declare-fun z () (Tuple Int Int Int Int Int Int)) + +(assert (bag.member x A)) +(assert (bag.member y B)) +(assert (bag.member z C)) + +(assert (distinct x y ((_ tuple_project 0 1 2) z) ((_ tuple_project 3 4 5) z))) + +(check-sat)