This PR adds inferences for operators: intersection, duplicate_removal, and empty bags during post check.
It also fixes a bug in SolverState::getElements
namespace bags {
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
- : d_state(s), d_im(im), d_termReg(tr)
+ : d_state(s), d_ig(&d_state), d_im(im), d_termReg(tr)
{
d_zero = NodeManager::currentNM()->mkConst(Rational(0));
d_one = NodeManager::currentNM()->mkConst(Rational(1));
{
d_state.initialize();
+ checkDisequalBagTerms();
+
// At this point, all bag and count representatives should be in the solver
// state.
for (const Node& bag : d_state.getBags())
Kind k = n.getKind();
switch (k)
{
+ case kind::EMPTYBAG: checkEmpty(n); break;
case kind::MK_BAG: checkMkBag(n); break;
case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
case kind::UNION_MAX: checkUnionMax(n); break;
+ case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
+ case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
default: break;
}
it++;
return elements;
}
+void BagSolver::checkEmpty(const Node& n)
+{
+ Assert(n.getKind() == EMPTYBAG);
+ for (const Node& e : d_state.getElements(n))
+ {
+ InferInfo i = d_ig.empty(n, e);
+ i.process(&d_im, true);
+ }
+}
+
void BagSolver::checkUnionDisjoint(const Node& n)
{
Assert(n.getKind() == UNION_DISJOINT);
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.unionDisjoint(n, e);
+ InferInfo i = d_ig.unionDisjoint(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.unionMax(n, e);
+ InferInfo i = d_ig.unionMax(n, e);
+ i.process(&d_im, true);
+ }
+}
+
+void BagSolver::checkIntersectionMin(const Node& n)
+{
+ Assert(n.getKind() == INTERSECTION_MIN);
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.intersection(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.differenceSubtract(n, e);
+ InferInfo i = d_ig.differenceSubtract(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
<< " are: " << d_state.getElements(n) << std::endl;
for (const Node& e : d_state.getElements(n))
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.mkBag(n, e);
+ InferInfo i = d_ig.mkBag(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.nonNegativeCount(bag, element);
+ InferInfo i = d_ig.nonNegativeCount(bag, element);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
void BagSolver::checkDifferenceRemove(const Node& n)
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.differenceRemove(n, e);
+ InferInfo i = d_ig.differenceRemove(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
+ }
+}
+
+void BagSolver::checkDuplicateRemoval(Node n)
+{
+ Assert(n.getKind() == DUPLICATE_REMOVAL);
+ set<Node> elements;
+ const set<Node>& downwards = d_state.getElements(n);
+ const set<Node>& upwards = d_state.getElements(n[0]);
+
+ elements.insert(downwards.begin(), downwards.end());
+ elements.insert(upwards.begin(), upwards.end());
+
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.duplicateRemoval(n, e);
+ i.process(&d_im, true);
+ }
+}
+
+void BagSolver::checkDisequalBagTerms()
+{
+ for (const Node& n : d_state.getDisequalBagTerms())
+ {
+ InferInfo info = d_ig.bagDisequality(n);
+ info.process(&d_im, true);
}
}
#include "context/cdhashset.h"
#include "context/cdlist.h"
#include "theory/bags/infer_info.h"
+#include "theory/bags/inference_generator.h"
#include "theory/bags/inference_manager.h"
#include "theory/bags/normal_form.h"
#include "theory/bags/solver_state.h"
void postCheck();
private:
+ /** apply inference rules for empty bags */
+ void checkEmpty(const Node& n);
/**
* apply inference rules for MK_BAG operator.
* Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
void checkUnionDisjoint(const Node& n);
/** apply inference rules for union max */
void checkUnionMax(const Node& n);
+ /** apply inference rules for intersection_min operator */
+ void checkIntersectionMin(const Node& n);
/** apply inference rules for difference subtract */
void checkDifferenceSubtract(const Node& n);
/** apply inference rules for difference remove */
void checkDifferenceRemove(const Node& n);
+ /** apply inference rules for duplicate removal operator */
+ void checkDuplicateRemoval(Node n);
/** apply non negative constraints for multiplicities */
void checkNonNegativeCountTerms(const Node& bag, const Node& element);
+ /** apply inference rules for disequal bag terms */
+ void checkDisequalBagTerms();
/** The solver state object */
SolverState& d_state;
+ /** The inference generator object*/
+ InferenceGenerator d_ig;
/** Reference to the inference manager for the theory of bags */
InferenceManager& d_im;
/** Reference to the term registry of theory of bags */
TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
im->trustedLemma(trustedLemma);
}
+
+ Trace("bags::InferInfo::process") << (*this) << std::endl;
+
return true;
}
};
typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
-InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
+InferInfo InferenceGenerator::bagDisequality(Node n)
{
- Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
- Assert(n[0][0].getType().isBag());
+ Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
- Node A = n[0][0];
- Node B = n[0][1];
+ Node A = n[0];
+ Node B = n[1];
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DISEQUALITY;
- inferInfo.d_premises.push_back(reason);
TypeNode elementType = A.getType().getBagElementType();
BoundVarManager* bvm = d_nm->getBoundVarManager();
Node disEqual = countA.eqNode(countB).notNode();
- inferInfo.d_premises.push_back(n);
+ inferInfo.d_premises.push_back(n.notNode());
inferInfo.d_conclusion = disEqual;
return inferInfo;
}
return skolem;
}
-InferInfo InferenceGenerator::bagEmpty(Node e)
+InferInfo InferenceGenerator::empty(Node n, Node e)
{
- EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
- Node empty = d_nm->mkConst(emptyBag);
+ Assert(n.getKind() == kind::EMPTYBAG);
+ Assert(e.getType() == n.getType().getBagElementType());
+
InferInfo inferInfo;
+ Node skolem = getSkolem(n, inferInfo);
inferInfo.d_id = Inference::BAG_EMPTY;
- Node count = getMultiplicityTerm(e, empty);
+ Node count = getMultiplicityTerm(e, skolem);
Node equal = count.eqNode(d_zero);
inferInfo.d_conclusion = equal;
*/
InferInfo mkBag(Node n, Node e);
/**
- * @param n is (not (= A B)) where A, B are bags of type (Bag E)
+ * @param n is (= A B) where A, B are bags of type (Bag E), and
+ * (not (= A B)) is an assertion in the equality engine
* @return an inference that represents the following implication
* (=>
* (not (= A B))
* (not (= (count e A) (count e B))))
* where e is a fresh skolem of type E.
*/
- InferInfo bagDisequality(Node n, Node reason);
+ InferInfo bagDisequality(Node n);
/**
+ * @param n is (as emptybag (Bag E))
* @param e is a node of Type E
* @return an inference that represents the following implication
* (=>
* true
- * (= 0 (count e (as emptybag (Bag E)))))
+ * (= 0 (count e skolem)))
+ * where skolem = (as emptybag (Bag String))
*/
- InferInfo bagEmpty(Node e);
+ InferInfo empty(Node n, Node e);
/**
* @param n is (union_disjoint A B) where A, B are bags of type (Bag E)
* @param e is a node of Type E
const std::set<Node>& SolverState::getElements(Node B)
{
Node bag = getRepresentative(B);
- return d_bagElements[B];
+ return d_bagElements[bag];
}
+const std::set<Node>& SolverState::getDisequalBagTerms() { return d_deq; }
+
void SolverState::reset()
{
d_bagElements.clear();
d_bags.clear();
+ d_deq.clear();
}
void SolverState::initialize()
{
reset();
collectBagsAndCountTerms();
+ collectDisequalBagTerms();
}
void SolverState::collectBagsAndCountTerms()
{
- Trace("SolverState::collectBagsAndCountTerms")
- << "SolverState::collectBagsAndCountTerms start" << endl;
eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
while (!repIt.isFinished())
{
Node eqc = (*repIt);
- Trace("SolverState::collectBagsAndCountTerms")
- << "[" << eqc << "]: " << endl;
+ Trace("bags-eqc") << "Eqc [ " << eqc << " ] = { ";
if (eqc.getType().isBag())
{
while (!it.isFinished())
{
Node n = (*it);
+ Trace("bags-eqc") << (*it) << " ";
Kind k = n.getKind();
if (k == MK_BAG)
{
}
++it;
}
-
+ Trace("bags-eqc") << " } " << std::endl;
++repIt;
}
- Trace("SolverState::collectBagsAndCountTerms")
- << "SolverState::collectBagsAndCountTerms end" << endl;
+ Trace("bags-eqc") << "bag representatives: " << d_bags << endl;
+ Trace("bags-eqc") << "bag elements: " << d_bagElements << endl;
+}
+
+void SolverState::collectDisequalBagTerms()
+{
+ eq::EqClassIterator it = eq::EqClassIterator(d_false, d_ee);
+ while (!it.isFinished())
+ {
+ Node n = (*it);
+ if (n.getKind() == EQUAL && n[0].getType().isBag())
+ {
+ Trace("bags-eqc") << "Disequal terms: " << n << std::endl;
+ d_deq.insert(n);
+ }
+ ++it;
+ }
}
} // namespace bags
const std::set<Node>& getElements(Node B);
/** initialize bag and count terms */
void initialize();
+ /** return disequal bag terms */
+ const std::set<Node>& getDisequalBagTerms();
private:
/** clear all bags data structures */
void reset();
- /** collect bags' representatives and all count terms.
- * This function is called during postCheck */
+ /**
+ * collect bags' representatives and all count terms.
+ * This function is called during postCheck
+ */
void collectBagsAndCountTerms();
+ /**
+ * collect disequal bag terms. This function is called during postCheck.
+ */
+ void collectDisequalBagTerms();
/** constants */
Node d_true;
Node d_false;
std::set<Node> d_bags;
/** bag -> associated elements */
std::map<Node, std::set<Node>> d_bagElements;
+ /** Disequal bag terms */
+ std::set<Node> d_deq;
}; /* class SolverState */
} // namespace bags
case BAG_FROM_SET:
case BAG_TO_SET:
case BAG_IS_SINGLETON:
- case DUPLICATE_REMOVAL:
{
std::stringstream ss;
ss << "Term of kind " << n.getKind() << " is not supported yet";
void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
-void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
-{
- TypeNode t1 = n1.getType();
- if (t1.isBag())
- {
- InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason);
- info.process(d_inferManager, true);
- }
-}
+void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) {}
void TheoryBags::NotifyClass::eqNotifyNewClass(TNode n)
{
regress1/bug800.smt2
regress1/bags/difference_remove1.smt2
regress1/bags/disequality.smt2
+ regress1/bags/duplicate_removal1.smt2
+ regress1/bags/duplicate_removal2.smt2
+ regress1/bags/emptybag1.smt2
+ regress1/bags/intersection_min1.smt2
+ regress1/bags/intersection_min2.smt2
regress1/bags/issue5759.smt2
regress1/bags/subbag1.smt2
regress1/bags/subbag2.smt2
--- /dev/null
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(assert (= B (duplicate_removal A)))
+(assert (distinct (as emptybag (Bag String)) A B))
+(check-sat)
--- /dev/null
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(assert (= B (duplicate_removal A)))
+(assert (distinct (as emptybag (Bag String)) A B))
+(assert (= B (union_max A B)))
+(check-sat)
\ No newline at end of file
--- /dev/null
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun x () String)
+(declare-fun y () Int)
+(assert (= x "x"))
+(assert (= A (as emptybag (Bag String))))
+(assert (= (bag.count x A) y))
+(assert(> y 1))
+(check-sat)
--- /dev/null
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(declare-fun C () (Bag String))
+(assert (= C (intersection_min A B)))
+(assert (distinct (as emptybag (Bag String)) C))
+(assert (distinct A B C))
+(check-sat)
\ No newline at end of file
--- /dev/null
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(declare-fun C () (Bag String))
+(assert (= C (intersection_min A B)))
+(assert (= C (union_disjoint A B)))
+(assert (distinct (as emptybag (Bag String)) C))
+(check-sat)