From: mudathirmahgoub Date: Fri, 8 Jan 2021 16:07:50 +0000 (-0600) Subject: Add bags inference generator (#5731) X-Git-Tag: cvc5-1.0.0~2392 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=63d27f031f8942607d869080d0e2cfb6078d40b1;p=cvc5.git Add bags inference generator (#5731) This PR adds inference generator for basic bag rules. --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f01c948db..7e294443c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -456,8 +456,14 @@ libcvc4_add_sources( theory/atom_requests.h theory/bags/bags_rewriter.cpp theory/bags/bags_rewriter.h + theory/bags/bag_solver.cpp + theory/bags/bag_solver.h theory/bags/bags_statistics.cpp theory/bags/bags_statistics.h + theory/bags/infer_info.cpp + theory/bags/infer_info.h + theory/bags/inference_generator.cpp + theory/bags/inference_generator.h theory/bags/inference_manager.cpp theory/bags/inference_manager.h theory/bags/make_bag_op.cpp diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index b11a01628..49974d30d 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -3418,6 +3418,20 @@ Term Solver::mkTermHelper(Kind kind, const std::vector& children) const Node singleton = getNodeManager()->mkSingleton(type, *children[0].d_node); res = Term(this, singleton).getExpr(); } + else if (kind == api::MK_BAG) + { + // the type of the term is the same as the type of the internal node + // see Term::getSort() + TypeNode type = children[0].d_node->getType(); + // Internally NodeManager::mkBag needs a type argument + // to construct a bag, since there is no difference between + // integers and reals (both are Rationals). + // At the API, mkReal and mkInteger are different and therefore the + // element type can be used safely here. + Node bag = getNodeManager()->mkBag( + type, *children[0].d_node, *children[1].d_node); + res = Term(this, bag).getExpr(); + } else { res = d_exprMgr->mkExpr(k, echildren); diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 49fbe73ef..3c038546a 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -151,12 +151,12 @@ bool TypeNode::isFiniteInternal(bool usortFinite) TypeNode tnc = getArrayConstituentType(); if (!tnc.isFiniteInternal(usortFinite)) { - // arrays with consistuent type that is infinite are infinite + // arrays with constituent type that is infinite are infinite ret = false; } else if (getArrayIndexType().isFiniteInternal(usortFinite)) { - // arrays with both finite consistuent and index types are finite + // arrays with both finite constituent and index types are finite ret = true; } else @@ -170,6 +170,11 @@ bool TypeNode::isFiniteInternal(bool usortFinite) { ret = getSetElementType().isFiniteInternal(usortFinite); } + else if (isBag()) + { + // there are infinite bags for all element types + ret = false; + } else if (isFunction()) { ret = true; diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index f069a486f..36703fd6d 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -642,7 +642,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(api::INTERSECTION_MIN, "intersection_min"); addOperator(api::DIFFERENCE_SUBTRACT, "difference_subtract"); addOperator(api::DIFFERENCE_REMOVE, "difference_remove"); - addOperator(api::SUBBAG, "bag.is_included"); + addOperator(api::SUBBAG, "subbag"); addOperator(api::BAG_COUNT, "bag.count"); addOperator(api::DUPLICATE_REMOVAL, "duplicate_removal"); addOperator(api::MK_BAG, "bag"); diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 50bb79a9a..ccf9c4164 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -813,7 +813,30 @@ void Smt2Printer::toStream(std::ostream& out, case kind::UNIVERSE_SET:out << "(as univset " << n.getType() << ")";break; // bags - case kind::BAG_TYPE: out << smtKindString(k, d_variant) << " "; break; + case kind::BAG_TYPE: + case kind::UNION_MAX: + case kind::UNION_DISJOINT: + case kind::INTERSECTION_MIN: + case kind::DIFFERENCE_SUBTRACT: + case kind::DIFFERENCE_REMOVE: + case kind::SUBBAG: + case kind::BAG_COUNT: + case kind::DUPLICATE_REMOVAL: + case kind::BAG_CARD: + case kind::BAG_CHOOSE: + case kind::BAG_IS_SINGLETON: + case kind::BAG_FROM_SET: + case kind::BAG_TO_SET: out << smtKindString(k, d_variant) << " "; break; + case kind::MK_BAG: + { + // print (bag (mkBag_op Real) 1 3) as (bag 1.0 3) + out << smtKindString(k, d_variant) << " "; + TypeNode elemType = n.getType().getBagElementType(); + toStreamCastToType( + out, n[0], toDepth < 0 ? toDepth : toDepth - 1, elemType); + out << " " << n[1] << ")"; + return; + } // fp theory case kind::FLOATINGPOINT_FP: @@ -1170,6 +1193,20 @@ static string smtKindString(Kind k, Variant v) // bag theory case kind::BAG_TYPE: return "Bag"; + case kind::UNION_MAX: return "union_max"; + case kind::UNION_DISJOINT: return "union_disjoint"; + case kind::INTERSECTION_MIN: return "intersection_min"; + case kind::DIFFERENCE_SUBTRACT: return "difference_subtract"; + case kind::DIFFERENCE_REMOVE: return "difference_remove"; + case kind::SUBBAG: return "subbag"; + case kind::BAG_COUNT: return "bag.count"; + case kind::DUPLICATE_REMOVAL: return "duplicate_removal"; + case kind::MK_BAG: return "bag"; + case kind::BAG_CARD: return "bag.card"; + case kind::BAG_CHOOSE: return "bag.choose"; + case kind::BAG_IS_SINGLETON: return "bag.is_singleton"; + case kind::BAG_FROM_SET: return "bag.from_set"; + case kind::BAG_TO_SET: return "bag.to_set"; // fp theory case kind::FLOATINGPOINT_FP: return "fp"; diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp new file mode 100644 index 000000000..5621a7c1c --- /dev/null +++ b/src/theory/bags/bag_solver.cpp @@ -0,0 +1,109 @@ +/********************* */ +/*! \file bag_solver.cpp + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief solver for the theory of bags. + ** + ** solver for the theory of bags. + **/ + +#include "theory/bags/bag_solver.h" + +#include "theory/bags/inference_generator.h" + +using namespace std; +using namespace CVC4::context; +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace bags { + +BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr) + : d_state(s), d_im(im), d_termReg(tr) +{ + d_zero = NodeManager::currentNM()->mkConst(Rational(0)); + d_one = NodeManager::currentNM()->mkConst(Rational(1)); + d_true = NodeManager::currentNM()->mkConst(true); + d_false = NodeManager::currentNM()->mkConst(false); +} + +BagSolver::~BagSolver() {} + +void BagSolver::postCheck() +{ + for (const Node& n : d_state.getBags()) + { + Kind k = n.getKind(); + switch (k) + { + case kind::MK_BAG: checkMkBag(n); break; + case kind::UNION_DISJOINT: checkUnionDisjoint(n); break; + case kind::UNION_MAX: checkUnionMax(n); break; + case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break; + default: break; + } + } +} + +void BagSolver::checkUnionDisjoint(const Node& n) +{ + Assert(n.getKind() == UNION_DISJOINT); + TypeNode elementType = n.getType().getBagElementType(); + for (const Node& e : d_state.getElements(elementType)) + { + InferenceGenerator ig(&d_state); + InferInfo i = ig.unionDisjoint(n, e); + i.process(&d_im, true); + Trace("bags::BagSolver::postCheck") << i << endl; + } +} + +void BagSolver::checkUnionMax(const Node& n) +{ + Assert(n.getKind() == UNION_MAX); + TypeNode elementType = n.getType().getBagElementType(); + for (const Node& e : d_state.getElements(elementType)) + { + InferenceGenerator ig(&d_state); + InferInfo i = ig.unionMax(n, e); + i.process(&d_im, true); + Trace("bags::BagSolver::postCheck") << i << endl; + } +} + +void BagSolver::checkDifferenceSubtract(const Node& n) +{ + Assert(n.getKind() == DIFFERENCE_SUBTRACT); + TypeNode elementType = n.getType().getBagElementType(); + for (const Node& e : d_state.getElements(elementType)) + { + InferenceGenerator ig(&d_state); + InferInfo i = ig.differenceSubtract(n, e); + i.process(&d_im, true); + Trace("bags::BagSolver::postCheck") << i << endl; + } +} +void BagSolver::checkMkBag(const Node& n) +{ + Assert(n.getKind() == MK_BAG); + TypeNode elementType = n.getType().getBagElementType(); + for (const Node& e : d_state.getElements(elementType)) + { + InferenceGenerator ig(&d_state); + InferInfo i = ig.mkBag(n, e); + i.process(&d_im, true); + Trace("bags::BagSolver::postCheck") << i << endl; + } +} + +} // namespace bags +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h new file mode 100644 index 000000000..48583d134 --- /dev/null +++ b/src/theory/bags/bag_solver.h @@ -0,0 +1,70 @@ +/********************* */ +/*! \file bag_solver.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief solver for the theory of bags. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__BAG__SOLVER_H +#define CVC4__THEORY__BAG__SOLVER_H + +#include "context/cdhashset.h" +#include "context/cdlist.h" +#include "theory/bags/infer_info.h" +#include "theory/bags/inference_manager.h" +#include "theory/bags/normal_form.h" +#include "theory/bags/solver_state.h" +#include "theory/bags/term_registry.h" + +namespace CVC4 { +namespace theory { +namespace bags { + +/** The solver for the theory of bags + * + */ +class BagSolver +{ + public: + BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr); + ~BagSolver(); + + void postCheck(); + + private: + /** apply inference rules for MK_BAG operator */ + void checkMkBag(const Node& n); + /** apply inference rules for union disjoint */ + void checkUnionDisjoint(const Node& n); + /** apply inference rules for union max */ + void checkUnionMax(const Node& n); + /** apply inference rules for difference subtract */ + void checkDifferenceSubtract(const Node& n); + + /** The solver state object */ + SolverState& d_state; + /** Reference to the inference manager for the theory of bags */ + InferenceManager& d_im; + /** Reference to the term registry of theory of bags */ + TermRegistry& d_termReg; + /** Commonly used constants */ + Node d_true; + Node d_false; + Node d_zero; + Node d_one; +}; /* class BagSolver */ + +} // namespace bags +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__BAG__SOLVER_H */ diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index aee57c74d..66886bfbf 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -41,6 +41,8 @@ BagsRewriter::BagsRewriter(HistogramStat* statistics) : d_statistics(statistics) { d_nm = NodeManager::currentNM(); + d_zero = d_nm->mkConst(Rational(0)); + d_one = d_nm->mkConst(Rational(1)); } RewriteResponse BagsRewriter::postRewrite(TNode n) @@ -51,7 +53,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) // no need to rewrite n if it is already in a normal form response = BagsRewriteResponse(n, Rewrite::NONE); } - else if(n.getKind() == EQUAL) + else if (n.getKind() == EQUAL) { response = postRewriteEqual(n); } @@ -162,12 +164,11 @@ BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const if (n[1].isConst() && n[1].getKind() == EMPTYBAG) { // (bag.count x emptybag) = 0 - return BagsRewriteResponse(d_nm->mkConst(Rational(0)), - Rewrite::COUNT_EMPTY); + return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY); } if (n[1].getKind() == MK_BAG && n[0] == n[1][0]) { - // (bag.count x (mkBag x c) = c where c > 0 is a constant + // (bag.count x (mkBag x c) = c return BagsRewriteResponse(n[1][1], Rewrite::COUNT_MK_BAG); } return BagsRewriteResponse(n, Rewrite::NONE); @@ -181,8 +182,7 @@ BagsRewriteResponse BagsRewriter::rewriteDuplicateRemoval(const TNode& n) const { // (duplicate_removal (mkBag x n)) = (mkBag x 1) // where n is a positive constant - Node one = NodeManager::currentNM()->mkConst(Rational(1)); - Node bag = d_nm->mkBag(n[0][0].getType(), n[0][0], one); + Node bag = d_nm->mkBag(n[0][0].getType(), n[0][0], d_one); return BagsRewriteResponse(bag, Rewrite::DUPLICATE_REMOVAL_MK_BAG); } return BagsRewriteResponse(n, Rewrite::NONE); @@ -444,8 +444,7 @@ BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const if (n[0].getKind() == MK_BAG) { // (bag.is_singleton (mkBag x c)) = (c == 1) - Node one = d_nm->mkConst(Rational(1)); - Node equal = n[0][1].eqNode(one); + Node equal = n[0][1].eqNode(d_one); return BagsRewriteResponse(equal, Rewrite::IS_SINGLETON_MK_BAG); } return BagsRewriteResponse(n, Rewrite::NONE); @@ -457,9 +456,8 @@ BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const if (n[0].getKind() == SINGLETON) { // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1) - Node one = d_nm->mkConst(Rational(1)); TypeNode type = n[0].getType().getSetElementType(); - Node bag = d_nm->mkBag(type, n[0][0], one); + Node bag = d_nm->mkBag(type, n[0][0], d_one); return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON); } return BagsRewriteResponse(n, Rewrite::NONE); @@ -484,20 +482,20 @@ BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const Assert(n.getKind() == kind::EQUAL); if (n[0] == n[1]) { - Node ret = NodeManager::currentNM()->mkConst(true); + Node ret = d_nm->mkConst(true); return BagsRewriteResponse(ret, Rewrite::EQ_REFL); } if (n[0].isConst() && n[1].isConst()) { - Node ret = NodeManager::currentNM()->mkConst(false); + Node ret = d_nm->mkConst(false); return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE); } // standard ordering if (n[0] > n[1]) { - Node ret = NodeManager::currentNM()->mkNode(kind::EQUAL, n[1], n[0]); + Node ret = d_nm->mkNode(kind::EQUAL, n[1], n[0]); return BagsRewriteResponse(ret, Rewrite::EQ_SYM); } return BagsRewriteResponse(n, Rewrite::NONE); diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index 8425e3b1f..fb76fb1c2 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -80,7 +80,7 @@ class BagsRewriter : public TheoryRewriter /** * rewrites for n include: * - (bag.count x emptybag) = 0 - * - (bag.count x (mkBag x c) = c where c > 0 is a constant + * - (bag.count x (bag x c) = c * - otherwise = n */ BagsRewriteResponse rewriteBagCount(const TNode& n) const; @@ -213,6 +213,8 @@ class BagsRewriter : public TheoryRewriter private: /** Reference to the rewriter statistics. */ NodeManager* d_nm; + Node d_zero; + Node d_one; /** Reference to the rewriter statistics. */ HistogramStat* d_statistics; }; /* class TheoryBagsRewriter */ diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp new file mode 100644 index 000000000..1244a43ac --- /dev/null +++ b/src/theory/bags/infer_info.cpp @@ -0,0 +1,111 @@ +/********************* */ +/*! \file infer_info.cpp + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of inference information utility. + **/ + +#include "theory/bags/infer_info.h" + +#include "theory/bags/inference_manager.h" + +namespace CVC4 { +namespace theory { +namespace bags { + +const char* toString(Inference i) +{ + switch (i) + { + case Inference::NONE: return "NONE"; + case Inference::BAG_MK_BAG: return "BAG_MK_BAG"; + case Inference::BAG_EQUALITY: return "BAG_EQUALITY"; + case Inference::BAG_DISEQUALITY: return "BAG_DISEQUALITY"; + case Inference::BAG_EMPTY: return "BAG_EMPTY"; + case Inference::BAG_UNION_DISJOINT: return "BAG_UNION_DISJOINT"; + case Inference::BAG_UNION_MAX: return "BAG_UNION_MAX"; + case Inference::BAG_INTERSECTION_MIN: return "BAG_INTERSECTION_MIN"; + case Inference::BAG_DIFFERENCE_SUBTRACT: return "BAG_DIFFERENCE_SUBTRACT"; + case Inference::BAG_DIFFERENCE_REMOVE: return "BAG_DIFFERENCE_REMOVE"; + case Inference::BAG_DUPLICATE_REMOVAL: return "BAG_DUPLICATE_REMOVAL"; + default: return "?"; + } +} + +std::ostream& operator<<(std::ostream& out, Inference i) +{ + out << toString(i); + return out; +} + +InferInfo::InferInfo() : d_id(Inference::NONE) {} + +bool InferInfo::process(TheoryInferenceManager* im, bool asLemma) +{ + Node lemma = d_conclusion; + if (d_premises.size() >= 2) + { + Node andNode = NodeManager::currentNM()->mkNode(kind::AND, d_premises); + lemma = andNode.impNode(lemma); + } + else if (d_premises.size() == 1) + { + lemma = d_premises[0].impNode(lemma); + } + if (asLemma) + { + TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); + return im->trustedLemma(trustedLemma); + } + Unimplemented(); +} + +bool InferInfo::isTrivial() const +{ + Assert(!d_conclusion.isNull()); + return d_conclusion.isConst() && d_conclusion.getConst(); +} + +bool InferInfo::isConflict() const +{ + Assert(!d_conclusion.isNull()); + return d_conclusion.isConst() && !d_conclusion.getConst(); +} + +bool InferInfo::isFact() const +{ + Assert(!d_conclusion.isNull()); + TNode atom = + d_conclusion.getKind() == kind::NOT ? d_conclusion[0] : d_conclusion; + return !atom.isConst() && atom.getKind() != kind::OR; +} + +Node InferInfo::getPremises() const +{ + // d_noExplain is a subset of d_ant + NodeManager* nm = NodeManager::currentNM(); + return nm->mkAnd(d_premises); +} + +std::ostream& operator<<(std::ostream& out, const InferInfo& ii) +{ + out << "(infer " << ii.d_id << " " << ii.d_conclusion << std::endl; + if (!ii.d_premises.empty()) + { + out << " :premise (" << ii.d_premises << ")" << std::endl; + } + + out << ")"; + return out; +} + +} // namespace bags +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/bags/infer_info.h b/src/theory/bags/infer_info.h new file mode 100644 index 000000000..3edbef737 --- /dev/null +++ b/src/theory/bags/infer_info.h @@ -0,0 +1,128 @@ +/********************* */ +/*! \file infer_info.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Inference information utility + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__BAGS__INFER_INFO_H +#define CVC4__THEORY__BAGS__INFER_INFO_H + +#include +#include + +#include "expr/node.h" +#include "theory/theory_inference.h" + +namespace CVC4 { +namespace theory { +namespace bags { + +/** + * Types of inferences used in the procedure + */ +enum class Inference : uint32_t +{ + NONE, + BAG_MK_BAG, + BAG_EQUALITY, + BAG_DISEQUALITY, + BAG_EMPTY, + BAG_UNION_DISJOINT, + BAG_UNION_MAX, + BAG_INTERSECTION_MIN, + BAG_DIFFERENCE_SUBTRACT, + BAG_DIFFERENCE_REMOVE, + BAG_DUPLICATE_REMOVAL +}; + +/** + * Converts an inference to a string. Note: This function is also used in + * `safe_print()`. Changing this functions name or signature will result in + * `safe_print()` printing "" instead of the proper strings for + * the enum values. + * + * @param i The inference + * @return The name of the inference + */ +const char* toString(Inference i); + +/** + * Writes an inference name to a stream. + * + * @param out The stream to write to + * @param i The inference to write to the stream + * @return The stream + */ +std::ostream& operator<<(std::ostream& out, Inference i); + +class InferenceManager; + +/** + * An inference. This is a class to track an unprocessed call to either + * send a fact, lemma, or conflict that is waiting to be asserted to the + * equality engine or sent on the output channel. + */ +class InferInfo : public TheoryInference +{ + public: + InferInfo(); + ~InferInfo() {} + /** Process this inference */ + bool process(TheoryInferenceManager* im, bool asLemma) override; + /** The inference identifier */ + Inference d_id; + /** The conclusion */ + Node d_conclusion; + /** + * The premise(s) of the inference, interpreted conjunctively. These are + * literals that currently hold in the equality engine. + */ + std::vector d_premises; + + /** + * A list of new skolems introduced as a result of this inference. They + * are mapped to by a length status, indicating the length constraint that + * can be assumed for them. + */ + std::vector d_newSkolem; + /** Is this infer info trivial? True if d_conc is true. */ + bool isTrivial() const; + /** + * Does this infer info correspond to a conflict? True if d_conc is false + * and it has no new premises (d_noExplain). + */ + bool isConflict() const; + /** + * Does this infer info correspond to a "fact". A fact is an inference whose + * conclusion should be added as an equality or predicate to the equality + * engine with no new external premises (d_noExplain). + */ + bool isFact() const; + /** Get premises */ + Node getPremises() const; +}; + +/** + * Writes an inference info to a stream. + * + * @param out The stream to write to + * @param ii The inference info to write to the stream + * @return The stream + */ +std::ostream& operator<<(std::ostream& out, const InferInfo& ii); + +} // namespace bags +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__BAGS__INFER_INFO_H */ diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp new file mode 100644 index 000000000..759ea1f0c --- /dev/null +++ b/src/theory/bags/inference_generator.cpp @@ -0,0 +1,273 @@ +/********************* */ +/*! \file inference_generator.cpp + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Inference generator utility + **/ + +#include "inference_generator.h" + +#include "expr/attribute.h" +#include "expr/bound_var_manager.h" +#include "expr/skolem_manager.h" +#include "theory/uf/equality_engine.h" + +namespace CVC4 { +namespace theory { +namespace bags { + +InferenceGenerator::InferenceGenerator(SolverState* state) : d_state(state) +{ + d_nm = NodeManager::currentNM(); + d_sm = d_nm->getSkolemManager(); + d_true = d_nm->mkConst(true); + d_zero = d_nm->mkConst(Rational(0)); + d_one = d_nm->mkConst(Rational(1)); +} + +InferInfo InferenceGenerator::mkBag(Node n, Node e) +{ + Assert(n.getKind() == kind::MK_BAG); + Assert(e.getType() == n.getType().getBagElementType()); + + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_MK_BAG; + Node count = getMultiplicitySkolem(e, n, inferInfo); + if (n[0] == e) + { + // TODO: refactor this with the rewriter + // (=> true (= (bag.count e (bag e c)) c)) + inferInfo.d_conclusion = count.eqNode(n[1]); + } + else + { + // (=> + // true + // (= (bag.count e (bag x c)) (ite (= e x) c 0))) + + Node same = d_nm->mkNode(kind::EQUAL, n[0], e); + Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero); + Node equal = count.eqNode(ite); + inferInfo.d_conclusion = equal; + } + return inferInfo; +} + +InferInfo InferenceGenerator::bagEquality(Node n, Node e) +{ + Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag()); + Assert(e.getType() == n[0].getType().getBagElementType()); + + Node A = n[0]; + Node B = n[1]; + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_EQUALITY; + inferInfo.d_premises.push_back(n); + Node countA = getMultiplicitySkolem(e, A, inferInfo); + Node countB = getMultiplicitySkolem(e, B, inferInfo); + + Node equal = countA.eqNode(countB); + inferInfo.d_conclusion = equal; + return inferInfo; +} + +struct BagsDeqAttributeId +{ +}; +typedef expr::Attribute BagsDeqAttribute; + +InferInfo InferenceGenerator::bagDisequality(Node n) +{ + Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL); + Assert(n[0][0].getType().isBag()); + + Node A = n[0][0]; + Node B = n[0][1]; + + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_DISEQUALITY; + + TypeNode elementType = A.getType().getBagElementType(); + + BoundVarManager* bvm = d_nm->getBoundVarManager(); + Node element = bvm->mkBoundVar(n, elementType); + SkolemManager* sm = d_nm->getSkolemManager(); + Node skolem = + sm->mkSkolem(element, + n, + "bag_disequal", + "an extensional lemma for disequality of two bags"); + + inferInfo.d_newSkolem.push_back(skolem); + + Node countA = getMultiplicitySkolem(skolem, A, inferInfo); + Node countB = getMultiplicitySkolem(skolem, B, inferInfo); + + Node disEqual = countA.eqNode(countB).notNode(); + + inferInfo.d_premises.push_back(n); + inferInfo.d_conclusion = disEqual; + return inferInfo; +} + +InferInfo InferenceGenerator::bagEmpty(Node e) +{ + EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType())); + Node empty = d_nm->mkConst(emptyBag); + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_EMPTY; + Node count = getMultiplicitySkolem(e, empty, inferInfo); + + Node equal = count.eqNode(d_zero); + inferInfo.d_conclusion = equal; + return inferInfo; +} + +InferInfo InferenceGenerator::unionDisjoint(Node n, Node e) +{ + Assert(n.getKind() == kind::UNION_DISJOINT && n[0].getType().isBag()); + Assert(e.getType() == n[0].getType().getBagElementType()); + + Node A = n[0]; + Node B = n[1]; + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_UNION_DISJOINT; + + Node countA = getMultiplicitySkolem(e, A, inferInfo); + Node countB = getMultiplicitySkolem(e, B, inferInfo); + Node count = getMultiplicitySkolem(e, n, inferInfo); + + Node sum = d_nm->mkNode(kind::PLUS, countA, countB); + Node equal = count.eqNode(sum); + + inferInfo.d_conclusion = equal; + return inferInfo; +} + +InferInfo InferenceGenerator::unionMax(Node n, Node e) +{ + Assert(n.getKind() == kind::UNION_MAX && n[0].getType().isBag()); + Assert(e.getType() == n[0].getType().getBagElementType()); + + Node A = n[0]; + Node B = n[1]; + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_UNION_MAX; + + Node countA = getMultiplicitySkolem(e, A, inferInfo); + Node countB = getMultiplicitySkolem(e, B, inferInfo); + Node count = getMultiplicitySkolem(e, n, inferInfo); + + Node gt = d_nm->mkNode(kind::GT, countA, countB); + Node max = d_nm->mkNode(kind::ITE, gt, countA, countB); + Node equal = count.eqNode(max); + + inferInfo.d_conclusion = equal; + return inferInfo; +} + +InferInfo InferenceGenerator::intersection(Node n, Node e) +{ + Assert(n.getKind() == kind::INTERSECTION_MIN && n[0].getType().isBag()); + Assert(e.getType() == n[0].getType().getBagElementType()); + + Node A = n[0]; + Node B = n[1]; + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_INTERSECTION_MIN; + + Node countA = getMultiplicitySkolem(e, A, inferInfo); + Node countB = getMultiplicitySkolem(e, B, inferInfo); + Node count = getMultiplicitySkolem(e, n, inferInfo); + + Node lt = d_nm->mkNode(kind::LT, countA, countB); + Node min = d_nm->mkNode(kind::ITE, lt, countA, countB); + Node equal = count.eqNode(min); + inferInfo.d_conclusion = equal; + return inferInfo; +} + +InferInfo InferenceGenerator::differenceSubtract(Node n, Node e) +{ + Assert(n.getKind() == kind::DIFFERENCE_SUBTRACT && n[0].getType().isBag()); + Assert(e.getType() == n[0].getType().getBagElementType()); + + Node A = n[0]; + Node B = n[1]; + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_DIFFERENCE_SUBTRACT; + + Node countA = getMultiplicitySkolem(e, A, inferInfo); + Node countB = getMultiplicitySkolem(e, B, inferInfo); + Node count = getMultiplicitySkolem(e, n, inferInfo); + + Node subtract = d_nm->mkNode(kind::MINUS, countA, countB); + Node gte = d_nm->mkNode(kind::GEQ, countA, countB); + Node difference = d_nm->mkNode(kind::ITE, gte, subtract, d_zero); + Node equal = count.eqNode(difference); + inferInfo.d_conclusion = equal; + return inferInfo; +} + +InferInfo InferenceGenerator::differenceRemove(Node n, Node e) +{ + Assert(n.getKind() == kind::DIFFERENCE_REMOVE && n[0].getType().isBag()); + Assert(e.getType() == n[0].getType().getBagElementType()); + + Node A = n[0]; + Node B = n[1]; + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_DIFFERENCE_REMOVE; + + Node countA = getMultiplicitySkolem(e, A, inferInfo); + Node countB = getMultiplicitySkolem(e, B, inferInfo); + Node count = getMultiplicitySkolem(e, n, inferInfo); + + Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero); + Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero); + Node equal = count.eqNode(difference); + inferInfo.d_conclusion = equal; + return inferInfo; +} + +InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e) +{ + Assert(n.getKind() == kind::DUPLICATE_REMOVAL && n[0].getType().isBag()); + Assert(e.getType() == n[0].getType().getBagElementType()); + + Node A = n[0]; + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_DUPLICATE_REMOVAL; + + Node countA = getMultiplicitySkolem(e, A, inferInfo); + Node count = getMultiplicitySkolem(e, n, inferInfo); + + Node gte = d_nm->mkNode(kind::GEQ, countA, d_one); + Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero); + Node equal = count.eqNode(ite); + inferInfo.d_conclusion = equal; + return inferInfo; +} + +Node InferenceGenerator::getMultiplicitySkolem(Node element, + Node bag, + InferInfo& inferInfo) +{ + Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag); + Node skolem = d_state->registerBagElement(count); + eq::EqualityEngine* ee = d_state->getEqualityEngine(); + ee->assertEquality(skolem.eqNode(count), true, d_nm->mkConst(true)); + inferInfo.d_newSkolem.push_back(skolem); + return skolem; +} + +} // namespace bags +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h new file mode 100644 index 000000000..b56997088 --- /dev/null +++ b/src/theory/bags/inference_generator.h @@ -0,0 +1,172 @@ +/********************* */ +/*! \file inference_generator.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Inference generator utility + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__BAGS__INFERENCE_GENERATOR_H +#define CVC4__THEORY__BAGS__INFERENCE_GENERATOR_H + +#include +#include + +#include "expr/node.h" +#include "infer_info.h" +#include "theory/bags/solver_state.h" + +namespace CVC4 { +namespace theory { +namespace bags { + +/** + * An inference generator class. This class is used by the core solver to + * generate lemmas + */ +class InferenceGenerator +{ + public: + InferenceGenerator(SolverState* state); + + /** + * @param n is (bag x c) of type (Bag E) + * @param e is a node of type E + * @return an inference that represents the following implication + * (=> + * true + * (= (bag.count e (bag x c)) (ite (= e x) c 0))) + */ + InferInfo mkBag(Node n, Node e); + + /** + * @param n is (= A B) where A, B are bags of type (Bag E) + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * (= A B) + * (= (count e A) (count e B))) + */ + InferInfo bagEquality(Node n, Node e); + /** + * @param n is (not (= A B)) where A, B are bags of type (Bag E) + * @return an inference that represents the following implication + * (=> + * (not (= A B)) + * (not (= (count e A) (count e B)))) + * where e is a fresh skolem of type E + */ + InferInfo bagDisequality(Node n); + /** + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * true + * (= 0 (count e (as emptybag (Bag E))))) + */ + InferInfo bagEmpty(Node e); + /** + * @param n is (union_disjoint A B) where A, B are bags of type (Bag E) + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * true + * (= (count e k_{(union_disjoint A B)}) + * (+ (count e A) (count e B)))) + * where k_{(union_disjoint A B)} is a unique purification skolem + * for (union_disjoint A B) + */ + InferInfo unionDisjoint(Node n, Node e); + /** + * @param n is (union_disjoint A B) where A, B are bags of type (Bag E) + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * true + * (= (count e (union_max A B)) + * (ite + * (> (count e A) (count e B)) + * (count e A) + * (count e B))))) + */ + InferInfo unionMax(Node n, Node e); + /** + * @param n is (intersection_min A B) where A, B are bags of type (Bag E) + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * true + * (= (count e (intersection_min A B)) + * (ite( + * (< (count e A) (count e B)) + * (count e A) + * (count e B))))) + */ + InferInfo intersection(Node n, Node e); + /** + * @param n is (difference_subtract A B) where A, B are bags of type (Bag E) + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * true + * (= (count e (difference_subtract A B)) + * (ite + * (>= (count e A) (count e B)) + * (- (count e A) (count e B)) + * 0)))) + */ + InferInfo differenceSubtract(Node n, Node e); + /** + * @param n is (difference_remove A B) where A, B are bags of type (Bag E) + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * true + * (= (count e (difference_remove A B)) + * (ite + * (= (count e B) 0) + * (count e A) + * 0)))) + */ + InferInfo differenceRemove(Node n, Node e); + /** + * @param n is (duplicate_removal A) where A is a bag of type (Bag E) + * @param e is a node of Type E + * @return an inference that represents the following implication + * (=> + * true + * (= (count e (duplicate_removal A)) + * (ite (>= (count e A) 1) 1 0)))) + */ + InferInfo duplicateRemoval(Node n, Node e); + + /** + * @param element of type T + * @param bag of type (bag T) + * @param inferInfo to store new skolem + * @return a skolem for (bag.count element bag) + */ + Node getMultiplicitySkolem(Node element, Node bag, InferInfo& inferInfo); + + private: + NodeManager* d_nm; + SkolemManager* d_sm; + SolverState* d_state; + Node d_true; + Node d_zero; + Node d_one; +}; + +} // namespace bags +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__BAGS__INFERENCE_GENERATOR_H */ diff --git a/src/theory/bags/inference_manager.cpp b/src/theory/bags/inference_manager.cpp index eb695725e..0ccd06922 100644 --- a/src/theory/bags/inference_manager.cpp +++ b/src/theory/bags/inference_manager.cpp @@ -30,6 +30,20 @@ InferenceManager::InferenceManager(Theory& t, d_false = NodeManager::currentNM()->mkConst(false); } +void InferenceManager::doPending() +{ + doPendingFacts(); + if (d_state.isInConflict()) + { + // just clear the pending vectors, nothing else to do + clearPendingLemmas(); + clearPendingPhaseRequirements(); + return; + } + doPendingLemmas(); + doPendingPhaseRequirements(); +} + } // namespace bags } // namespace theory } // namespace CVC4 diff --git a/src/theory/bags/inference_manager.h b/src/theory/bags/inference_manager.h index 90d188d33..67025548c 100644 --- a/src/theory/bags/inference_manager.h +++ b/src/theory/bags/inference_manager.h @@ -38,6 +38,16 @@ class InferenceManager : public InferenceManagerBuffered public: InferenceManager(Theory& t, SolverState& s, ProofNodeManager* pnm); + /** + * Do pending method. This processes all pending facts, lemmas and pending + * phase requests based on the policy of this manager. This means that + * we process the pending facts first and abort if in conflict. Otherwise, we + * process the pending lemmas and then the pending phase requirements. + * Notice that we process the pending lemmas even if there were facts. + */ + // TODO: refactor this before merge with theory of strings + void doPending(); + private: /** constants */ Node d_true; diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index 88c41a961..2a5c09125 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -158,7 +158,7 @@ Node NormalForm::evaluateBinaryOperation(const TNode& n, remainderOfA(elements, elementsB, itB); Trace("bags-evaluate") << "elements: " << elements << std::endl; - Node bag = constructBagFromElements(n.getType(), elements); + Node bag = constructConstantBagFromElements(n.getType(), elements); Trace("bags-evaluate") << "bag: " << bag << std::endl; return bag; } @@ -187,7 +187,7 @@ std::map NormalForm::getBagElements(TNode n) return elements; } -Node NormalForm::constructBagFromElements( +Node NormalForm::constructConstantBagFromElements( TypeNode t, const std::map& elements) { Assert(t.isBag()); @@ -209,6 +209,26 @@ Node NormalForm::constructBagFromElements( return bag; } +Node NormalForm::constructBagFromElements(TypeNode t, + const std::map& elements) +{ + Assert(t.isBag()); + NodeManager* nm = NodeManager::currentNM(); + if (elements.empty()) + { + return nm->mkConst(EmptyBag(t)); + } + TypeNode elementType = t.getBagElementType(); + std::map::const_reverse_iterator it = elements.rbegin(); + Node bag = nm->mkBag(elementType, it->first, it->second); + while (++it != elements.rend()) + { + Node n = nm->mkBag(elementType, it->first, it->second); + bag = nm->mkNode(UNION_DISJOINT, n, bag); + } + return bag; +} + Node NormalForm::evaluateMakeBag(TNode n) { // the case where n is const should be handled earlier. @@ -262,7 +282,7 @@ Node NormalForm::evaluateDuplicateRemoval(TNode n) { it->second = one; } - Node bag = constructBagFromElements(n[0].getType(), newElements); + Node bag = constructConstantBagFromElements(n[0].getType(), newElements); return bag; } @@ -624,7 +644,7 @@ Node NormalForm::evaluateFromSet(TNode n) bagElements[element] = one; } TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType()); - Node bag = constructBagFromElements(bagType, bagElements); + Node bag = constructConstantBagFromElements(bagType, bagElements); return bag; } diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h index e88fe6fef..acb13bb44 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/normal_form.h @@ -61,9 +61,19 @@ class NormalForm * multiplicities * @return a constant bag that contains */ - static Node constructBagFromElements( + static Node constructConstantBagFromElements( TypeNode t, const std::map& elements); + /** + * construct a constant bag from node elements + * @param t the type of the returned bag + * @param elements a map whose keys are constant elements and values are + * multiplicities + * @return a constant bag that contains + */ + static Node constructBagFromElements(TypeNode t, + const std::map& elements); + private: /** * a high order helper function that return a constant bag that is the result diff --git a/src/theory/bags/solver_state.cpp b/src/theory/bags/solver_state.cpp index d58b68657..744f6de9f 100644 --- a/src/theory/bags/solver_state.cpp +++ b/src/theory/bags/solver_state.cpp @@ -14,6 +14,11 @@ #include "theory/bags/solver_state.h" +#include "expr/attribute.h" +#include "expr/bound_var_manager.h" +#include "expr/skolem_manager.h" +#include "theory/uf/equality_engine.h" + using namespace std; using namespace CVC4::kind; @@ -30,6 +35,51 @@ SolverState::SolverState(context::Context* c, d_false = NodeManager::currentNM()->mkConst(false); } +struct BagsCountAttributeId +{ +}; +typedef expr::Attribute BagsCountAttribute; + +void SolverState::registerClass(TNode n) +{ + TypeNode t = n.getType(); + if (!t.isBag()) + { + return; + } + d_bags.insert(n); +} + +Node SolverState::registerBagElement(TNode n) +{ + Assert(n.getKind() == BAG_COUNT); + Node element = n[0]; + TypeNode elementType = element.getType(); + Node bag = n[1]; + d_elements[elementType].insert(element); + NodeManager* nm = NodeManager::currentNM(); + BoundVarManager* bvm = nm->getBoundVarManager(); + Node multiplicity = bvm->mkBoundVar(n, nm->integerType()); + Node equal = n.eqNode(multiplicity); + SkolemManager* sm = nm->getSkolemManager(); + Node skolem = sm->mkSkolem( + multiplicity, + equal, + "bag_multiplicity", + "an extensional lemma for multiplicity of an element in a bag"); + d_count[bag][element] = skolem; + Trace("bags::SolverState::registerBagElement") + << "New skolem: " << skolem << " for " << n << std::endl; + + return skolem; +} + +std::set& SolverState::getBags() { return d_bags; } + +std::set& SolverState::getElements(TypeNode t) { return d_elements[t]; } + +std::map& SolverState::getBagElements(Node B) { return d_count[B]; } + } // namespace bags } // namespace theory } // namespace CVC4 diff --git a/src/theory/bags/solver_state.h b/src/theory/bags/solver_state.h index 9a4bfdae7..8d70ee8f7 100644 --- a/src/theory/bags/solver_state.h +++ b/src/theory/bags/solver_state.h @@ -31,10 +31,24 @@ class SolverState : public TheoryState public: SolverState(context::Context* c, context::UserContext* u, Valuation val); + void registerClass(TNode n); + + Node registerBagElement(TNode n); + + std::set& getBags(); + + std::set& getElements(TypeNode t); + + std::map& getBagElements(Node B); + private: /** constants */ Node d_true; Node d_false; + std::set d_bags; + std::map> d_elements; + /** bag -> element -> multiplicity */ + std::map> d_count; }; /* class SolverState */ } // namespace bags diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 6ba1ad87a..b3251e464 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -14,6 +14,8 @@ #include "theory/bags/theory_bags.h" +#include "theory/theory_model.h" + using namespace CVC4::kind; namespace CVC4 { @@ -28,10 +30,13 @@ TheoryBags::TheoryBags(context::Context* c, ProofNodeManager* pnm) : Theory(THEORY_BAGS, c, u, out, valuation, logicInfo, pnm), d_state(c, u, valuation), - d_im(*this, d_state, pnm), + d_im(*this, d_state, nullptr), + d_ig(&d_state), d_notify(*this, d_im), d_statistics(), - d_rewriter(&d_statistics.d_rewrites) + d_rewriter(&d_statistics.d_rewrites), + d_termReg(d_state, d_im), + d_solver(d_state, d_im, d_termReg) { // use the official theory state and inference manager objects d_theoryState = &d_state; @@ -70,7 +75,60 @@ void TheoryBags::finishInit() d_equalityEngine->addFunctionKind(BAG_TO_SET); } -void TheoryBags::postCheck(Effort level) {} +void TheoryBags::postCheck(Effort effort) +{ + d_im.doPendingFacts(); + // TODO: clean this before merge Assert(d_strat.isStrategyInit()); + if (!d_state.isInConflict() && !d_valuation.needCheck()) + // TODO: clean this before merge && d_strat.hasStrategyEffort(e)) + { + Trace("bags::TheoryBags::postCheck") << "effort: " << std::endl; + + // TODO: clean this before merge ++(d_statistics.d_checkRuns); + bool sentLemma = false; + bool hadPending = false; + Trace("bags-check") << "Full effort check..." << std::endl; + do + { + d_im.reset(); + // TODO: clean this before merge ++(d_statistics.d_strategyRuns); + Trace("bags-check") << " * Run strategy..." << std::endl; + // TODO: clean this before merge runStrategy(e); + + d_solver.postCheck(); + + // remember if we had pending facts or lemmas + hadPending = d_im.hasPending(); + // Send the facts *and* the lemmas. We send lemmas regardless of whether + // we send facts since some lemmas cannot be dropped. Other lemmas are + // otherwise avoided by aborting the strategy when a fact is ready. + d_im.doPending(); + // Did we successfully send a lemma? Notice that if hasPending = true + // and sentLemma = false, then the above call may have: + // (1) had no pending lemmas, but successfully processed pending facts, + // (2) unsuccessfully processed pending lemmas. + // In either case, we repeat the strategy if we are not in conflict. + sentLemma = d_im.hasSentLemma(); + if (Trace.isOn("bags-check")) + { + // TODO: clean this Trace("bags-check") << " ...finish run strategy: "; + Trace("bags-check") << (hadPending ? "hadPending " : ""); + Trace("bags-check") << (sentLemma ? "sentLemma " : ""); + Trace("bags-check") << (d_state.isInConflict() ? "conflict " : ""); + if (!hadPending && !sentLemma && !d_state.isInConflict()) + { + Trace("bags-check") << "(none)"; + } + Trace("bags-check") << std::endl; + } + // repeat if we did not add a lemma or conflict, and we had pending + // facts or lemmas. + } while (!d_state.isInConflict() && !sentLemma && hadPending); + } + Trace("bags-check") << "Theory of bags, done check : " << effort << std::endl; + Assert(!d_im.hasPendingFact()); + Assert(!d_im.hasPendingLemma()); +} void TheoryBags::notifyFact(TNode atom, bool polarity, @@ -80,8 +138,43 @@ void TheoryBags::notifyFact(TNode atom, } bool TheoryBags::collectModelValues(TheoryModel* m, - const std::set& termBag) + const std::set& termSet) { + Trace("bags-model") << "TheoryBags : Collect model values" << std::endl; + + Trace("bags-model") << "Term set: " << termSet << std::endl; + + // get the relevant bag equivalence classes + for (const Node& n : termSet) + { + TypeNode tn = n.getType(); + if (!tn.isBag()) + { + continue; + } + Node r = d_state.getRepresentative(n); + std::map elements = d_state.getBagElements(r); + Trace("bags-model") << "Elements of bag " << n << " are: " << std::endl + << elements << std::endl; + std::map elementReps; + for (std::pair pair : elements) + { + Node key = d_state.getRepresentative(pair.first); + Node value = d_state.getRepresentative(pair.second); + elementReps[key] = value; + } + Node rep = NormalForm::constructBagFromElements(tn, elementReps); + rep = Rewriter::rewrite(rep); + + Trace("bags-model") << "rep of " << n << " is: " << rep << std::endl; + for (std::pair pair : elementReps) + { + m->assertSkeleton(pair.first); + m->assertSkeleton(pair.second); + } + m->assertEquality(rep, n, true); + m->assertSkeleton(rep); + } return true; } @@ -101,41 +194,73 @@ void TheoryBags::presolve() {} /**************************** eq::NotifyClass *****************************/ -void TheoryBags::eqNotifyNewClass(TNode t) +void TheoryBags::eqNotifyNewClass(TNode n) { - Assert(false) << "Not implemented yet" << std::endl; + Kind k = n.getKind(); + d_state.registerClass(n); + if (n.getKind() == MK_BAG) + { + // TODO: refactor this before merge + /* + * (bag x m) generates the lemma (and (= s (count x (bag x m))) (= s m)) + * where s is a fresh skolem variable + */ + NodeManager* nm = NodeManager::currentNM(); + Node count = nm->mkNode(BAG_COUNT, n[0], n); + Node skolem = d_state.registerBagElement(count); + Node countSkolem = count.eqNode(skolem); + Node skolemMultiplicity = n[1].eqNode(skolem); + Node lemma = countSkolem.andNode(skolemMultiplicity); + TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); + d_im.trustedLemma(trustedLemma); + } + if (k == BAG_COUNT) + { + /* + * (count x A) generates the lemma (= s (count x A)) + * where s is a fresh skolem variable + */ + Node skolem = d_state.registerBagElement(n); + Node lemma = n.eqNode(skolem); + TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); + d_im.trustedLemma(trustedLemma); + } } -void TheoryBags::eqNotifyMerge(TNode t1, TNode t2) -{ - Assert(false) << "Not implemented yet" << std::endl; -} +void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {} -void TheoryBags::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) +void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) { - Assert(false) << "Not implemented yet" << std::endl; + TypeNode t1 = n1.getType(); + if (t1.isBag()) + { + InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode()); + Node lemma = reason.impNode(info.d_conclusion); + TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); + d_im.trustedLemma(trustedLemma); + } } -void TheoryBags::NotifyClass::eqNotifyNewClass(TNode t) +void TheoryBags::NotifyClass::eqNotifyNewClass(TNode n) { Debug("bags-eq") << "[bags-eq] eqNotifyNewClass:" - << " t = " << t << std::endl; - d_theory.eqNotifyNewClass(t); + << " n = " << n << std::endl; + d_theory.eqNotifyNewClass(n); } -void TheoryBags::NotifyClass::eqNotifyMerge(TNode t1, TNode t2) +void TheoryBags::NotifyClass::eqNotifyMerge(TNode n1, TNode n2) { Debug("bags-eq") << "[bags-eq] eqNotifyMerge:" - << " t1 = " << t1 << " t2 = " << t2 << std::endl; - d_theory.eqNotifyMerge(t1, t2); + << " n1 = " << n1 << " n2 = " << n2 << std::endl; + d_theory.eqNotifyMerge(n1, n2); } -void TheoryBags::NotifyClass::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) +void TheoryBags::NotifyClass::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) { Debug("bags-eq") << "[bags-eq] eqNotifyDisequal:" - << " t1 = " << t1 << " t2 = " << t2 << " reason = " << reason + << " n1 = " << n1 << " n2 = " << n2 << " reason = " << reason << std::endl; - d_theory.eqNotifyDisequal(t1, t2, reason); + d_theory.eqNotifyDisequal(n1, n2, reason); } } // namespace bags diff --git a/src/theory/bags/theory_bags.h b/src/theory/bags/theory_bags.h index 03676c2d2..09784add9 100644 --- a/src/theory/bags/theory_bags.h +++ b/src/theory/bags/theory_bags.h @@ -19,9 +19,11 @@ #include +#include "theory/bags/bag_solver.h" #include "theory/bags/bags_rewriter.h" #include "theory/bags/bags_statistics.h" #include "theory/bags/inference_manager.h" +#include "theory/bags/inference_generator.h" #include "theory/bags/solver_state.h" #include "theory/theory.h" #include "theory/theory_eq_notify.h" @@ -58,7 +60,7 @@ class TheoryBags : public Theory //--------------------------------- standard check /** Post-check, called after the fact queue of the theory is processed. */ - void postCheck(Effort level) override; + void postCheck(Effort effort) override; /** Notify fact */ void notifyFact(TNode atom, bool pol, TNode fact, bool isInternal) override; //--------------------------------- end standard check @@ -82,9 +84,9 @@ class TheoryBags : public Theory : TheoryEqNotifyClass(inferenceManager), d_theory(theory) { } - void eqNotifyNewClass(TNode t) override; - void eqNotifyMerge(TNode t1, TNode t2) override; - void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override; + void eqNotifyNewClass(TNode n) override; + void eqNotifyMerge(TNode n1, TNode n2) override; + void eqNotifyDisequal(TNode n1, TNode n2, TNode reason) override; private: TheoryBags& d_theory; @@ -94,15 +96,21 @@ class TheoryBags : public Theory SolverState d_state; /** The inference manager */ InferenceManager d_im; + /** The inference generator */ + InferenceGenerator d_ig; /** Instance of the above class */ NotifyClass d_notify; /** Statistics for the theory of bags. */ BagsStatistics d_statistics; /** The theory rewriter for this theory. */ BagsRewriter d_rewriter; + /** The term registry for this theory */ + TermRegistry d_termReg; + /** the main solver for bags */ + BagSolver d_solver; - void eqNotifyNewClass(TNode t); - void eqNotifyMerge(TNode t1, TNode t2); + void eqNotifyNewClass(TNode n); + void eqNotifyMerge(TNode n1, TNode n2); void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); }; /* class TheoryBags */ diff --git a/src/theory/strings/base_solver.h b/src/theory/strings/base_solver.h index 929034e42..66ec35277 100644 --- a/src/theory/strings/base_solver.h +++ b/src/theory/strings/base_solver.h @@ -225,7 +225,7 @@ class BaseSolver * * This set contains a set of nodes that are not representatives of their * congruence class. This set is used to skip reasoning about terms in - * various inference schemas implemnted by this class. + * various inference schemas implemented by this class. */ NodeSet d_congruent; /** diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 72757ef32..7ceb3b4b2 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1415,6 +1415,12 @@ set(regress_1_tests regress1/bug681.smt2 regress1/bug694-Unapply1.scala-0.smt2 regress1/bug800.smt2 + regress1/bags/disequality.smt2 + regress1/bags/subbag1.smt2 + regress1/bags/subbag2.smt2 + regress1/bags/union_disjoint.smt2 + regress1/bags/union_max1.smt2 + regress1/bags/union_max2.smt2 regress1/bv/bench_38.delta.smt2 regress1/bv/bug787.smt2 regress1/bv/bug_extract_mult_leading_bit.smt2 diff --git a/test/regress/regress1/bags/disequality.smt2 b/test/regress/regress1/bags/disequality.smt2 new file mode 100644 index 000000000..e49935fac --- /dev/null +++ b/test/regress/regress1/bags/disequality.smt2 @@ -0,0 +1,14 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun C () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(assert (distinct A B C)) +(assert (> (bag.count x A) 0)) +(assert (> (bag.count y B) 0)) +(assert (= (bag.count x A) (bag.count x B))) +(assert (= (bag.count y A) (bag.count y B))) +(assert (distinct x y)) +(check-sat) \ No newline at end of file diff --git a/test/regress/regress1/bags/subbag1.smt2 b/test/regress/regress1/bags/subbag1.smt2 new file mode 100644 index 000000000..055e47a17 --- /dev/null +++ b/test/regress/regress1/bags/subbag1.smt2 @@ -0,0 +1,11 @@ +(set-logic ALL) +(set-info :status unsat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(assert (= x 1)) +(assert (subbag A B)) +(assert (subbag B A)) +(assert (= (bag.count x A) 5)) +(assert (= (bag.count x B) 10)) +(check-sat) \ No newline at end of file diff --git a/test/regress/regress1/bags/subbag2.smt2 b/test/regress/regress1/bags/subbag2.smt2 new file mode 100644 index 000000000..6d5cde362 --- /dev/null +++ b/test/regress/regress1/bags/subbag2.smt2 @@ -0,0 +1,13 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(assert (subbag A B)) +(assert (subbag B A)) +(assert (= (bag.count x A) x)) +(assert (= (bag.count y A) x)) +(assert (distinct x y)) +(assert (= (bag.count x B) (+ 1 y))) +(check-sat) \ No newline at end of file diff --git a/test/regress/regress1/bags/union_disjoint.smt2 b/test/regress/regress1/bags/union_disjoint.smt2 new file mode 100644 index 000000000..d30ed4b14 --- /dev/null +++ b/test/regress/regress1/bags/union_disjoint.smt2 @@ -0,0 +1,10 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(assert (= A (union_disjoint (bag x 1) (bag y 2)))) +(assert (= A (union_disjoint B (bag y 2)))) +(assert (= x y)) +(check-sat) diff --git a/test/regress/regress1/bags/union_max1.smt2 b/test/regress/regress1/bags/union_max1.smt2 new file mode 100644 index 000000000..d278527b9 --- /dev/null +++ b/test/regress/regress1/bags/union_max1.smt2 @@ -0,0 +1,10 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(assert (= A (union_max (bag x 1) (bag y 2)))) +(assert (= A (union_disjoint B (bag y 2)))) +(assert (= x y)) +(check-sat) \ No newline at end of file diff --git a/test/regress/regress1/bags/union_max2.smt2 b/test/regress/regress1/bags/union_max2.smt2 new file mode 100644 index 000000000..dd4bceff5 --- /dev/null +++ b/test/regress/regress1/bags/union_max2.smt2 @@ -0,0 +1,11 @@ +(set-logic ALL) +(set-info :status unsat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(assert (= A (union_max (bag x 1) (bag y 2)))) +(assert (= A (union_disjoint B (bag y 2)))) +(assert (= x y)) +(assert (distinct (as emptybag (Bag Int)) B)) +(check-sat) \ No newline at end of file