From 7b581e4bd7a139efff494a729eb17e02ad7126fe Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Tue, 19 Apr 2022 13:49:14 -0500 Subject: [PATCH] Add table.project evaluator (#8632) This PR adds evaluator for table.project operator updates the parser to interpret "Table" as a table with zero columns --- src/CMakeLists.txt | 2 + src/api/cpp/cvc5.cpp | 7 +++ src/api/cpp/cvc5_kind.h | 17 +++++- src/parser/smt2/Smt2.g | 14 +++++ src/parser/smt2/smt2.cpp | 3 +- src/printer/smt2/smt2_printer.cpp | 17 ++++++ src/theory/bags/bags_utils.cpp | 32 ++++++++++ src/theory/bags/bags_utils.h | 7 +++ src/theory/bags/kinds | 14 ++++- src/theory/bags/table_project_op.cpp | 25 ++++++++ src/theory/bags/table_project_op.h | 45 ++++++++++++++ src/theory/bags/theory_bags.cpp | 3 + src/theory/bags/theory_bags_type_rules.cpp | 60 +++++++++++++++++++ src/theory/bags/theory_bags_type_rules.h | 11 ++++ src/theory/datatypes/datatypes_rewriter.cpp | 23 +------ .../datatypes/theory_datatypes_type_rules.cpp | 10 +--- .../datatypes/theory_datatypes_type_rules.h | 2 +- src/theory/datatypes/tuple_project_op.cpp | 17 ++++-- src/theory/datatypes/tuple_project_op.h | 37 ++++++++---- src/theory/datatypes/tuple_utils.cpp | 42 +++++++++++++ src/theory/datatypes/tuple_utils.h | 16 +++++ test/regress/cli/CMakeLists.txt | 1 + .../cli/regress1/bags/table_project1.smt2 | 21 +++++++ 23 files changed, 378 insertions(+), 48 deletions(-) create mode 100644 src/theory/bags/table_project_op.cpp create mode 100644 src/theory/bags/table_project_op.h create mode 100644 test/regress/cli/regress1/bags/table_project1.smt2 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f414db97d..ea10d8fcf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -571,6 +571,8 @@ libcvc5_add_sources( theory/bags/solver_state.h theory/bags/strategy.cpp theory/bags/strategy.h + theory/bags/table_project_op.cpp + theory/bags/table_project_op.h theory/bags/term_registry.cpp theory/bags/term_registry.h theory/bags/theory_bags.cpp diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 5de54018c..d3c28aa06 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -69,6 +69,7 @@ #include "smt/model.h" #include "smt/smt_mode.h" #include "smt/solver_engine.h" +#include "theory/bags/table_project_op.h" #include "theory/datatypes/tuple_project_op.h" #include "theory/logic_info.h" #include "theory/theory_model.h" @@ -328,6 +329,7 @@ const static std::unordered_map> KIND_ENUM(BAG_FILTER, internal::Kind::BAG_FILTER), KIND_ENUM(BAG_FOLD, internal::Kind::BAG_FOLD), KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT), + KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT), /* Strings ---------------------------------------------------------- */ KIND_ENUM(STRING_CONCAT, internal::Kind::STRING_CONCAT), KIND_ENUM(STRING_IN_REGEXP, internal::Kind::STRING_IN_REGEXP), @@ -643,6 +645,8 @@ const static std::unordered_map& args) const case TUPLE_PROJECT: res = mkOpHelper(kind, internal::TupleProjectOp(args)); break; + case TABLE_PROJECT: + res = mkOpHelper(kind, internal::TableProjectOp(args)); + break; default: if (nargs == 0) { diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index e1ba67405..bbc5cfdd8 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -3695,7 +3695,22 @@ enum Kind : int32_t * \endrst */ TABLE_PRODUCT, - + /** + * Table projection operator extends tuple projection operator to tables. + * + * - Arity: ``1`` + * - ``1:`` Term of tuple Sort + * + * - Indices: ``n`` + * - ``1..n:`` The table indices to project + * + * - Create Term of this Kind with: + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * - Create Op of this kind with: + * - Solver::mkOp(Kind, const std::vector&) const + */ + TABLE_PROJECT, /* Strings --------------------------------------------------------------- */ /** diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 224da51f4..e66636e22 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1448,6 +1448,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2] cvc5::Op op = SOLVER->mkOp(cvc5::TUPLE_PROJECT, indices); expr = SOLVER->mkTerm(op, {expr}); } + | LPAREN_TOK TABLE_PROJECT_TOK term[expr,expr2] RPAREN_TOK + { + std::vector indices; + cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_PROJECT, indices); + expr = SOLVER->mkTerm(op, {expr}); + } | /* an atomic term (a term with no subterms) */ termAtomic[atomTerm] { expr = atomTerm; } ; @@ -1589,6 +1595,13 @@ identifier[cvc5::ParseOp& p] p.d_kind = cvc5::TUPLE_PROJECT; p.d_op = SOLVER->mkOp(cvc5::TUPLE_PROJECT, numerals); } + | TABLE_PROJECT_TOK nonemptyNumeralList[numerals] + { + // we adopt a special syntax (_ table.project i_1 ... i_n) where + // i_1, ..., i_n are numerals + p.d_kind = cvc5::TABLE_PROJECT; + p.d_op = SOLVER->mkOp(cvc5::TABLE_PROJECT, numerals); + } | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals] { cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName); @@ -2197,6 +2210,7 @@ FORALL_TOK : 'forall'; CHAR_TOK : { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_STRINGS) }? 'char'; TUPLE_CONST_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple'; TUPLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_DATATYPES) }? 'tuple.project'; +TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.project'; FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card'; HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->'; diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 41a98c613..05a1e1982 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -258,6 +258,7 @@ void Smt2::addSepOperators() { void Smt2::addCoreSymbols() { defineType("Bool", d_solver->getBooleanSort(), true); + defineType("Table", d_solver->mkBagSort(d_solver->mkTupleSort({})), true); defineVar("true", d_solver->mkTrue(), true); defineVar("false", d_solver->mkFalse(), true); addOperator(cvc5::AND, "and"); @@ -1123,7 +1124,7 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) Trace("parser") << "applyParseOp: return selector " << ret << std::endl; return ret; } - else if (p.d_kind == cvc5::TUPLE_PROJECT) + else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT) { cvc5::Term ret = d_solver->mkTerm(p.d_op, args); Trace("parser") << "applyParseOp: return projection " << ret << std::endl; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 91a3f2186..7d5dc1cc5 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -42,6 +42,7 @@ #include "proof/unsat_core.h" #include "smt/command.h" #include "smt_util/boolean_simplification.h" +#include "theory/bags/table_project_op.h" #include "theory/arrays/theory_arrays_rewriter.h" #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/datatypes/tuple_project_op.h" @@ -792,6 +793,21 @@ void Smt2Printer::toStream(std::ostream& out, } return; } + case kind::TABLE_PROJECT: + { + TableProjectOp op = n.getOperator().getConst(); + if (op.getIndices().empty()) + { + // e.g. (table.project A) + out << "table.project " << n[0] << ")"; + } + else + { + // e.g. ((_ table.project 2 4 4) A) + out << "(_ table.project" << op << ") " << n[0] << ")"; + } + return; + } case kind::CONSTRUCTOR_TYPE: { out << n[n.getNumChildren()-1]; @@ -1168,6 +1184,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_FILTER: return "bag.filter"; case kind::BAG_FOLD: return "bag.fold"; case kind::TABLE_PRODUCT: return "table.product"; + case kind::TABLE_PROJECT: return "table.project"; // fp theory case kind::FLOATINGPOINT_FP: return "fp"; diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp index 3c5089943..e71923248 100644 --- a/src/theory/bags/bags_utils.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -18,6 +18,7 @@ #include "expr/dtype_cons.h" #include "expr/emptybag.h" #include "smt/logic_exception.h" +#include "table_project_op.h" #include "theory/datatypes/tuple_utils.h" #include "theory/sets/normal_form.h" #include "theory/type_enumerator.h" @@ -141,6 +142,7 @@ Node BagsUtils::evaluate(TNode n) case BAG_FILTER: return evaluateBagFilter(n); case BAG_FOLD: return evaluateBagFold(n); case TABLE_PRODUCT: return evaluateProduct(n); + case TABLE_PROJECT: return evaluateTableProject(n); default: break; } Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n @@ -829,6 +831,36 @@ Node BagsUtils::evaluateProduct(TNode n) return ret; } +Node BagsUtils::evaluateTableProject(TNode n) +{ + Assert(n.getKind() == TABLE_PROJECT); + // Examples + // -------- + // - ((_ table.project 1) (bag (tuple true "a") 4)) = (bag (tuple "a") 4) + // - (table.project (bag.union_disjoint + // (bag (tuple "a") 4) + // (bag (tuple "b") 3))) = (bag tuple 7) + + Node A = n[0]; + + std::map elementsA = BagsUtils::getBagElements(A); + + std::map elements; + std::vector indices = + n.getOperator().getConst().getIndices(); + + for (const auto& [a, countA] : elementsA) + { + Node element = TupleUtils::getTupleProjection(indices, a); + // multiple elements could be projected to the same tuple. + // Zero is the default value for Rational values. + elements[element] += countA; + } + + Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements); + return ret; +} + } // namespace bags } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/bags/bags_utils.h b/src/theory/bags/bags_utils.h index 41b8c14d3..42e7b0caf 100644 --- a/src/theory/bags/bags_utils.h +++ b/src/theory/bags/bags_utils.h @@ -109,6 +109,13 @@ class BagsUtils */ static Node evaluateProduct(TNode n); + /** + * @param n of the form ((_ table.project i_1 ... i_n) A) where A is a + * constant + * @return the evaluation of the projection + */ + static Node evaluateTableProject(TNode n); + private: /** * a high order helper function that return a constant bag that is the result diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 49e1d5624..49bca83fb 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -114,9 +114,21 @@ construle BAG_UNION_DISJOINT ::cvc5::internal::theory::bags::BinaryOperatorT construle BAG_MAKE ::cvc5::internal::theory::bags::BagMakeTypeRule -# bag.product operator returns the cross product of two tables +# table.product operator returns the cross product of two tables operator TABLE_PRODUCT 2 "table cross product" +# table.project operator extends datatypes tuple_project operator to a bag of tuples +constant TABLE_PROJECT_OP \ + class \ + TableProjectOp \ + ::cvc5::internal::TableProjectOpHashFunction \ + "theory/bags/table_project_op.h" \ + "operator for TABLE_PROJECT; payload is an instance of the cvc5::internal::TableProjectOp class" + +parameterized TABLE_PROJECT TABLE_PROJECT_OP 1 "table projection" + typerule TABLE_PRODUCT ::cvc5::internal::theory::bags::TableProductTypeRule +typerule TABLE_PROJECT_OP "SimpleTypeRule" +typerule TABLE_PROJECT ::cvc5::internal::theory::bags::TableProjectTypeRule endtheory diff --git a/src/theory/bags/table_project_op.cpp b/src/theory/bags/table_project_op.cpp new file mode 100644 index 000000000..426753d8a --- /dev/null +++ b/src/theory/bags/table_project_op.cpp @@ -0,0 +1,25 @@ +/****************************************************************************** + * Top contributors (to current version): + * Mudathir Mohamed + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2022 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * A class for TableProjectOp operator. + */ + +#include "table_project_op.h" + +namespace cvc5::internal { + +TableProjectOp::TableProjectOp(std::vector indices) + : ProjectOp(std::move(indices)) +{ +} + +} // namespace cvc5::internal diff --git a/src/theory/bags/table_project_op.h b/src/theory/bags/table_project_op.h new file mode 100644 index 000000000..a061537e9 --- /dev/null +++ b/src/theory/bags/table_project_op.h @@ -0,0 +1,45 @@ +/****************************************************************************** + * Top contributors (to current version): + * Mudathir Mohamed + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2022 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * A class for TableProjectOp operator. + */ + +#include "cvc5_public.h" + +#ifndef CVC5__TABLE_PROJECT_OP_H +#define CVC5__TABLE_PROJECT_OP_H + +#include "theory/datatypes/tuple_project_op.h" + +namespace cvc5::internal { + +/** + * The class is an operator for kind project used to project elements in a + * table. It stores the indices of projected elements + */ +class TableProjectOp : public ProjectOp +{ + public: + explicit TableProjectOp(std::vector indices); + TableProjectOp(const TableProjectOp& op) = default; +}; /* class TableProjectOp */ + +/** + * Hash function for the TupleProjectOpHashFunction objects. + */ +struct TableProjectOpHashFunction : public ProjectOpHashFunction +{ +}; /* struct TupleProjectOpHashFunction */ + +} // namespace cvc5::internal + +#endif /* CVC5__TABLE_PROJECT_OP_H */ diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 4307dcbe3..92ea5ecca 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -80,6 +80,8 @@ void TheoryBags::finishInit() d_equalityEngine->addFunctionKind(BAG_CARD); d_equalityEngine->addFunctionKind(BAG_FROM_SET); d_equalityEngine->addFunctionKind(BAG_TO_SET); + d_equalityEngine->addFunctionKind(TABLE_PRODUCT); + d_equalityEngine->addFunctionKind(TABLE_PROJECT); } TrustNode TheoryBags::ppRewrite(TNode atom, std::vector& lems) @@ -453,6 +455,7 @@ void TheoryBags::preRegisterTerm(TNode n) case BAG_FROM_SET: case BAG_TO_SET: case BAG_IS_SINGLETON: + case TABLE_PROJECT: { std::stringstream ss; ss << "Term of kind " << n.getKind() << " is not supported yet"; diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index fd47f006b..ef2a5a350 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -18,9 +18,14 @@ #include #include "base/check.h" +#include "expr/dtype.h" +#include "expr/dtype_cons.h" #include "expr/emptybag.h" +#include "table_project_op.h" #include "theory/bags/bag_make_op.h" #include "theory/bags/bags_utils.h" +#include "theory/datatypes/tuple_project_op.h" +#include "theory/datatypes/tuple_utils.h" #include "util/cardinality.h" #include "util/rational.h" @@ -494,6 +499,61 @@ TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager, return retType; } +TypeNode TableProjectTypeRule::computeType(NodeManager* nm, TNode n, bool check) +{ + Assert(n.getKind() == kind::TABLE_PROJECT && n.hasOperator() + && n.getOperator().getKind() == kind::TABLE_PROJECT_OP); + TableProjectOp op = n.getOperator().getConst(); + const std::vector& indices = op.getIndices(); + TypeNode bagType = n[0].getType(check); + if (check) + { + if (n.getNumChildren() != 1) + { + std::stringstream ss; + ss << "operands in term " << n << " are " << n.getNumChildren() + << ", but TABLE_PROJECT expects 1 operand."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + if (!bagType.isBag()) + { + std::stringstream ss; + ss << "TABLE_PROJECT operator expects a table. Found '" << n[0] + << "' of type '" << bagType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + TypeNode tupleType = bagType.getBagElementType(); + if (!tupleType.isTuple()) + { + std::stringstream ss; + ss << "TABLE_PROJECT operator expects a table. Found '" << n[0] + << "' of type '" << bagType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + + // make sure all indices are less than the length of the tuple type + DType dType = tupleType.getDType(); + DTypeConstructor constructor = dType[0]; + size_t numArgs = constructor.getNumArgs(); + for (uint32_t index : indices) + { + std::stringstream ss; + if (index >= numArgs) + { + ss << "Index " << index << " in term " << n << " is >= " << numArgs + << " which is the number of columns in " << n[0] << "."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + } + TypeNode tupleType = bagType.getBagElementType(); + TypeNode retTupleType = + datatypes::TupleUtils::getTupleProjectionType(indices, tupleType); + return nm->mkBagType(retTupleType); +} + Cardinality BagsProperties::computeCardinality(TypeNode type) { return Cardinality::INTEGERS; diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index a80523415..54329b405 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -168,6 +168,17 @@ struct TableProductTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct BagFoldTypeRule */ +/** + * Table project is indexed by a list of indices (n_1, ..., n_m). It ensures + * that the argument is a bag of tuples whose arity k is greater than each n_i + * for i = 1, ..., m. If the argument is of type (Bag (Tuple T_1 ... T_k)), then + * the returned type is (Bag (Tuple T_{n_1} ... T_{n_m})). + */ +struct TableProjectTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct BagFoldTypeRule */ + struct BagsProperties { static Cardinality computeCardinality(TypeNode type); diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index fccac800d..307e7cb61 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -26,6 +26,7 @@ #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/datatypes/tuple_project_op.h" +#include "tuple_utils.h" #include "util/rational.h" #include "util/uninterpreted_sort_value.h" @@ -165,29 +166,11 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) // where each i_j is less than the length of t Trace("dt-rewrite-project") << "Rewrite project: " << in << std::endl; + TupleProjectOp op = in.getOperator().getConst(); std::vector indices = op.getIndices(); Node tuple = in[0]; - std::vector tupleTypes = tuple.getType().getTupleTypes(); - std::vector types; - std::vector elements; - for (uint32_t index : indices) - { - TypeNode type = tupleTypes[index]; - types.push_back(type); - } - TypeNode projectType = nm->mkTupleType(types); - const DType& dt = projectType.getDType(); - elements.push_back(dt[0].getConstructor()); - const DType& tupleDType = tuple.getType().getDType(); - const DTypeConstructor& constructor = tupleDType[0]; - for (uint32_t index : indices) - { - Node selector = constructor[index].getSelector(); - Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple); - elements.push_back(element); - } - Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements); + Node ret = TupleUtils::getTupleProjection(indices, tuple); Trace("dt-rewrite-project") << "Rewrite project: " << in << " ... " << ret << std::endl; diff --git a/src/theory/datatypes/theory_datatypes_type_rules.cpp b/src/theory/datatypes/theory_datatypes_type_rules.cpp index 6f27fdce8..f4450a57d 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.cpp +++ b/src/theory/datatypes/theory_datatypes_type_rules.cpp @@ -24,6 +24,7 @@ #include "expr/type_matcher.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/datatypes/tuple_project_op.h" +#include "theory/datatypes/tuple_utils.h" #include "util/rational.h" namespace cvc5::internal { @@ -560,14 +561,7 @@ TypeNode TupleProjectTypeRule::computeType(NodeManager* nm, TNode n, bool check) } } TypeNode tupleType = n[0].getType(check); - std::vector types; - DType dType = tupleType.getDType(); - DTypeConstructor constructor = dType[0]; - for (uint32_t index : indices) - { - types.push_back(constructor.getArgType(index)); - } - return nm->mkTupleType(types); + return TupleUtils::getTupleProjectionType(indices, tupleType); } TypeNode CodatatypeBoundVariableTypeRule::computeType(NodeManager* nodeManager, diff --git a/src/theory/datatypes/theory_datatypes_type_rules.h b/src/theory/datatypes/theory_datatypes_type_rules.h index f5d77d1ac..0696edc4e 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.h +++ b/src/theory/datatypes/theory_datatypes_type_rules.h @@ -161,7 +161,7 @@ class MatchBindCaseTypeRule /** * Tuple project is indexed by a list of indices (n_1, ..., n_m). It ensures - * that the argument is a tuple whose arity k is greater that each n_i for + * that the argument is a tuple whose arity k is greater than each n_i for * i = 1, ..., m. If the argument is of type (Tuple T_1 ... T_k), then the * returned type is (Tuple T_{n_1} ... T_{n_m}). */ diff --git a/src/theory/datatypes/tuple_project_op.cpp b/src/theory/datatypes/tuple_project_op.cpp index 335a03167..28027528f 100644 --- a/src/theory/datatypes/tuple_project_op.cpp +++ b/src/theory/datatypes/tuple_project_op.cpp @@ -10,7 +10,7 @@ * directory for licensing information. * **************************************************************************** * - * A class for TupleProjectOp operator. + * A class for ProjectOp operator. */ #include "tuple_project_op.h" @@ -21,7 +21,7 @@ namespace cvc5::internal { -std::ostream& operator<<(std::ostream& out, const TupleProjectOp& op) +std::ostream& operator<<(std::ostream& out, const ProjectOp& op) { for (const uint32_t& index : op.getIndices()) { @@ -30,7 +30,7 @@ std::ostream& operator<<(std::ostream& out, const TupleProjectOp& op) return out; } -size_t TupleProjectOpHashFunction::operator()(const TupleProjectOp& op) const +size_t ProjectOpHashFunction::operator()(const ProjectOp& op) const { // we expect most tuples to have length < 10. // Therefore we can implement a simple hash function @@ -42,16 +42,21 @@ size_t TupleProjectOpHashFunction::operator()(const TupleProjectOp& op) const return hash; } -TupleProjectOp::TupleProjectOp(std::vector indices) +ProjectOp::ProjectOp(std::vector indices) : d_indices(std::move(indices)) { } -const std::vector& TupleProjectOp::getIndices() const { return d_indices; } +const std::vector& ProjectOp::getIndices() const { return d_indices; } -bool TupleProjectOp::operator==(const TupleProjectOp& op) const +bool ProjectOp::operator==(const ProjectOp& op) const { return d_indices == op.d_indices; } +TupleProjectOp::TupleProjectOp(std::vector indices) + : ProjectOp(std::move(indices)) +{ +} + } // namespace cvc5::internal diff --git a/src/theory/datatypes/tuple_project_op.h b/src/theory/datatypes/tuple_project_op.h index 269c38b17..7d1b46ff9 100644 --- a/src/theory/datatypes/tuple_project_op.h +++ b/src/theory/datatypes/tuple_project_op.h @@ -26,32 +26,49 @@ namespace cvc5::internal { class TypeNode; /** - * The class is an operator for kind project used to project elements in a tuple - * It stores the indices of projected elements + * base class for TupleProjectOp, TupleProjectOp */ -class TupleProjectOp +class ProjectOp { public: - explicit TupleProjectOp(std::vector indices); - TupleProjectOp(const TupleProjectOp& op) = default; + explicit ProjectOp(std::vector indices); + ProjectOp(const ProjectOp& op) = default; /** return the indices of the projection */ const std::vector& getIndices() const; - bool operator==(const TupleProjectOp& op) const; + bool operator==(const ProjectOp& op) const; private: std::vector d_indices; -}; /* class TupleProjectOp */ +}; /* class ProjectOp */ + +std::ostream& operator<<(std::ostream& out, const ProjectOp& op); + +/** + * Hash function for the ProjectOpHashFunction objects. + */ +struct ProjectOpHashFunction +{ + size_t operator()(const ProjectOp& op) const; +}; /* struct ProjectOpHashFunction */ -std::ostream& operator<<(std::ostream& out, const TupleProjectOp& op); +/** + * The class is an operator for kind project used to project elements in a + * table. It stores the indices of projected elements + */ +class TupleProjectOp : public ProjectOp +{ + public: + explicit TupleProjectOp(std::vector indices); + TupleProjectOp(const TupleProjectOp& op) = default; +}; /* class TupleProjectOp */ /** * Hash function for the TupleProjectOpHashFunction objects. */ -struct TupleProjectOpHashFunction +struct TupleProjectOpHashFunction : public ProjectOpHashFunction { - size_t operator()(const TupleProjectOp& op) const; }; /* struct TupleProjectOpHashFunction */ } // namespace cvc5::internal diff --git a/src/theory/datatypes/tuple_utils.cpp b/src/theory/datatypes/tuple_utils.cpp index 87114d9b1..05528a644 100644 --- a/src/theory/datatypes/tuple_utils.cpp +++ b/src/theory/datatypes/tuple_utils.cpp @@ -36,6 +36,48 @@ Node TupleUtils::nthElementOfTuple(Node tuple, int n_th) APPLY_SELECTOR, dt[0].getSelectorInternal(tn, n_th), tuple); } +Node TupleUtils::getTupleProjection(const std::vector& indices, + Node tuple) +{ + std::vector tupleTypes = tuple.getType().getTupleTypes(); + std::vector types; + std::vector elements; + for (uint32_t index : indices) + { + TypeNode type = tupleTypes[index]; + types.push_back(type); + } + NodeManager* nm = NodeManager::currentNM(); + TypeNode projectType = nm->mkTupleType(types); + const DType& dt = projectType.getDType(); + elements.push_back(dt[0].getConstructor()); + const DType& tupleDType = tuple.getType().getDType(); + const DTypeConstructor& constructor = tupleDType[0]; + for (uint32_t index : indices) + { + Node selector = constructor[index].getSelector(); + Node element = nm->mkNode(kind::APPLY_SELECTOR, selector, tuple); + elements.push_back(element); + } + Node ret = nm->mkNode(kind::APPLY_CONSTRUCTOR, elements); + return ret; +} + +TypeNode TupleUtils::getTupleProjectionType( + const std::vector& indices, TypeNode tupleType) +{ + std::vector types; + DType dType = tupleType.getDType(); + DTypeConstructor constructor = dType[0]; + for (uint32_t index : indices) + { + types.push_back(constructor.getArgType(index)); + } + NodeManager* nm = NodeManager::currentNM(); + TypeNode retTupleType = nm->mkTupleType(types); + return retTupleType; +} + std::vector TupleUtils::getTupleElements(Node tuple) { Assert(tuple.getType().isTuple()); diff --git a/src/theory/datatypes/tuple_utils.h b/src/theory/datatypes/tuple_utils.h index f6651c50f..041121397 100644 --- a/src/theory/datatypes/tuple_utils.h +++ b/src/theory/datatypes/tuple_utils.h @@ -33,6 +33,22 @@ class TupleUtils */ static Node nthElementOfTuple(Node tuple, int n_th); + /** + * @param indices a list of indices for projected elements + * @param tuple a node of tuple type + * @return the projection of the tuple with the specified indices + */ + static Node getTupleProjection(const std::vector& indices, + Node tuple); + + /** + * @param indices a list of indices for projected elements + * @param tupleType the type of the original tuple + * @return the type of the projected tuple + */ + static TypeNode getTupleProjectionType(const std::vector& indices, + TypeNode tupleType); + /** * @param tuple a tuple node of the form (tuple a_1 ... a_n) * @return the vector [a_1, ... a_n] diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index 871fe3b66..dd4039cf0 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -1798,6 +1798,7 @@ set(regress_1_tests regress1/bags/proj-issue497.smt2 regress1/bags/subbag1.smt2 regress1/bags/subbag2.smt2 + regress1/bags/table_project1.smt2 regress1/bags/union_disjoint.smt2 regress1/bags/union_max1.smt2 regress1/bags/union_max2.smt2 diff --git a/test/regress/cli/regress1/bags/table_project1.smt2 b/test/regress/cli/regress1/bags/table_project1.smt2 new file mode 100644 index 000000000..882cf48ba --- /dev/null +++ b/test/regress/cli/regress1/bags/table_project1.smt2 @@ -0,0 +1,21 @@ +(set-logic HO_ALL) +(set-info :status sat) + +(declare-fun A () (Table String Int String Bool)) +(declare-fun B () (Table Int Bool String String)) +(declare-fun C () (Table String String)) +(declare-fun D () Table) + +(assert + (= A + (bag.union_disjoint + (bag (tuple "x" 0 "y" false) 5) + (bag (tuple "x" 1 "z" true) 10)))) + +; (bag.union_disjoint (bag (tuple 0 false "x" "y") 5) (bag (tuple 1 true "x" "z") 10))) +(assert (= B ((_ table.project 1 3 0 2) A))) +; (bag (tuple "x" "x") 15) +(assert (= C ((_ table.project 0 0) A))) +; (bag tuple 15) +(assert (= D (table.project A))) +(check-sat) -- 2.30.2