void BagSolver::postCheck()
{
+ d_state.initialize();
+
+ // At this point, all bag and count representatives should be in the solver
+ // state.
+ for (const Node& bag : d_state.getBags())
+ {
+ // iterate through all bags terms in each equivalent class
+ eq::EqClassIterator it =
+ eq::EqClassIterator(bag, d_state.getEqualityEngine());
+ while (!it.isFinished())
+ {
+ Node n = (*it);
+ Kind k = n.getKind();
+ switch (k)
+ {
+ case kind::MK_BAG: checkMkBag(n); break;
+ case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
+ case kind::UNION_MAX: checkUnionMax(n); break;
+ case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
+ case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
+ default: break;
+ }
+ it++;
+ }
+ }
+
+ // add non negative constraints for all multiplicities
for (const Node& n : d_state.getBags())
{
- Kind k = n.getKind();
- switch (k)
+ for (const Node& e : d_state.getElements(n))
{
- case kind::MK_BAG: checkMkBag(n); break;
- case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
- case kind::UNION_MAX: checkUnionMax(n); break;
- case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
- default: break;
+ checkNonNegativeCountTerms(n, e);
}
}
}
+set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
+{
+ set<Node> elements;
+ const set<Node>& downwards = d_state.getElements(n);
+ const set<Node>& upwards0 = d_state.getElements(n[0]);
+ const set<Node>& upwards1 = d_state.getElements(n[1]);
+
+ set_union(downwards.begin(),
+ downwards.end(),
+ upwards0.begin(),
+ upwards0.end(),
+ inserter(elements, elements.begin()));
+ elements.insert(upwards1.begin(), upwards1.end());
+ return elements;
+}
+
void BagSolver::checkUnionDisjoint(const Node& n)
{
Assert(n.getKind() == UNION_DISJOINT);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.unionDisjoint(n, e);
void BagSolver::checkUnionMax(const Node& n)
{
Assert(n.getKind() == UNION_MAX);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.unionMax(n, e);
void BagSolver::checkDifferenceSubtract(const Node& n)
{
Assert(n.getKind() == DIFFERENCE_SUBTRACT);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.differenceSubtract(n, e);
Trace("bags::BagSolver::postCheck") << i << endl;
}
}
+
void BagSolver::checkMkBag(const Node& n)
{
Assert(n.getKind() == MK_BAG);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ Trace("bags::BagSolver::postCheck")
+ << "BagSolver::checkMkBag Elements of " << n
+ << " are: " << d_state.getElements(n) << std::endl;
+ for (const Node& e : d_state.getElements(n))
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.mkBag(n, e);
Trace("bags::BagSolver::postCheck") << i << endl;
}
}
+void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
+{
+ InferenceGenerator ig(&d_state);
+ InferInfo i = ig.nonNegativeCount(bag, element);
+ i.process(&d_im, true);
+ Trace("bags::BagSolver::postCheck") << i << endl;
+}
+
+void BagSolver::checkDifferenceRemove(const Node& n)
+{
+ Assert(n.getKind() == DIFFERENCE_REMOVE);
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
+ {
+ InferenceGenerator ig(&d_state);
+ InferInfo i = ig.differenceRemove(n, e);
+ i.process(&d_im, true);
+ Trace("bags::BagSolver::postCheck") << i << endl;
+ }
+}
} // namespace bags
} // namespace theory
void postCheck();
private:
- /** apply inference rules for MK_BAG operator */
+ /**
+ * apply inference rules for MK_BAG operator.
+ * Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
+ * and (bag.count y n).
+ * This function will add inferences for the count terms as documented in
+ * InferenceGenerator::mkBag.
+ * Note that element y may not be in bag n. See the documentation of
+ * SolverState::getElements.
+ */
void checkMkBag(const Node& n);
+ /**
+ * @param n is a bag that has the form (op A B)
+ * @return the set union of known elements in (op A B) , A, and B.
+ */
+ std::set<Node> getElementsForBinaryOperator(const Node& n);
/** apply inference rules for union disjoint */
void checkUnionDisjoint(const Node& n);
/** apply inference rules for union max */
void checkUnionMax(const Node& n);
/** apply inference rules for difference subtract */
void checkDifferenceSubtract(const Node& n);
+ /** apply inference rules for difference remove */
+ void checkDifferenceRemove(const Node& n);
+ /** apply non negative constraints for multiplicities */
+ void checkNonNegativeCountTerms(const Node& bag, const Node& element);
/** The solver state object */
SolverState& d_state;
/**
* rewrites for n include:
- * - (mkBag x 0) = (emptybag T) where T is the type of x
- * - (mkBag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
+ * - (bag x 0) = (emptybag T) where T is the type of x
+ * - (bag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
* constant
* - otherwise = n
*/
/**
* rewrites for n include:
- * - (duplicate_removal (mkBag x n)) = (mkBag x 1)
+ * - (duplicate_removal (bag x n)) = (bag x 1)
* where n is a positive constant
*/
BagsRewriteResponse rewriteDuplicateRemoval(const TNode& n) const;
BagsRewriteResponse rewriteDifferenceRemove(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.choose (mkBag x c)) = x where c is a constant > 0
+ * - (bag.choose (bag x c)) = x where c is a constant > 0
* - otherwise = n
*/
BagsRewriteResponse rewriteChoose(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.card (mkBag x c)) = c where c is a constant > 0
+ * - (bag.card (bag x c)) = c where c is a constant > 0
* - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
* - otherwise = n
*/
/**
* rewrites for n include:
- * - (bag.is_singleton (mkBag x c)) = (c == 1)
+ * - (bag.is_singleton (bag x c)) = (c == 1)
*/
BagsRewriteResponse rewriteIsSingleton(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
+ * - (bag.from_set (singleton (singleton_op Int) x)) = (bag x 1)
*/
BagsRewriteResponse rewriteFromSet(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
+ * - (bag.to_set (bag x n)) = (singleton (singleton_op T) x)
* where n is a positive constant and T is the type of the bag's elements
*/
BagsRewriteResponse rewriteToSet(const TNode& n) const;
switch (i)
{
case Inference::NONE: return "NONE";
+ case Inference::BAG_NON_NEGATIVE_COUNT: return "BAG_NON_NEGATIVE_COUNT";
+ case Inference::BAG_MK_BAG_SAME_ELEMENT: return "BAG_MK_BAG_SAME_ELEMENT";
case Inference::BAG_MK_BAG: return "BAG_MK_BAG";
case Inference::BAG_EQUALITY: return "BAG_EQUALITY";
case Inference::BAG_DISEQUALITY: return "BAG_DISEQUALITY";
if (asLemma)
{
TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- return im->trustedLemma(trustedLemma);
+ im->trustedLemma(trustedLemma);
}
- Unimplemented();
+ else
+ {
+ Unimplemented();
+ }
+ for (const auto& pair : d_skolems)
+ {
+ Node n = pair.first.eqNode(pair.second);
+ TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
+ im->trustedLemma(trustedLemma);
+ }
+ return true;
}
bool InferInfo::isTrivial() const
return !atom.isConst() && atom.getKind() != kind::OR;
}
-Node InferInfo::getPremises() const
-{
- // d_noExplain is a subset of d_ant
- NodeManager* nm = NodeManager::currentNM();
- return nm->mkAnd(d_premises);
-}
-
std::ostream& operator<<(std::ostream& out, const InferInfo& ii)
{
- out << "(infer " << ii.d_id << " " << ii.d_conclusion << std::endl;
+ out << "(infer :id " << ii.d_id << std::endl;
+ out << ":conclusion " << ii.d_conclusion << std::endl;
if (!ii.d_premises.empty())
{
out << " :premise (" << ii.d_premises << ")" << std::endl;
}
-
+ out << ":skolems " << ii.d_skolems << std::endl;
out << ")";
return out;
}
enum class Inference : uint32_t
{
NONE,
+ BAG_NON_NEGATIVE_COUNT,
+ BAG_MK_BAG_SAME_ELEMENT,
BAG_MK_BAG,
BAG_EQUALITY,
BAG_DISEQUALITY,
bool process(TheoryInferenceManager* im, bool asLemma) override;
/** The inference identifier */
Inference d_id;
- /** The conclusion */
+ /** The conclusions */
Node d_conclusion;
/**
* The premise(s) of the inference, interpreted conjunctively. These are
std::vector<Node> d_premises;
/**
- * A list of new skolems introduced as a result of this inference. They
- * are mapped to by a length status, indicating the length constraint that
- * can be assumed for them.
+ * A map of nodes to their skolem variables introduced as a result of this
+ * inference.
*/
- std::vector<Node> d_newSkolem;
+ std::map<Node, Node> d_skolems;
/** Is this infer info trivial? True if d_conc is true. */
bool isTrivial() const;
/**
* engine with no new external premises (d_noExplain).
*/
bool isFact() const;
- /** Get premises */
- Node getPremises() const;
};
/**
d_one = d_nm->mkConst(Rational(1));
}
+InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e)
+{
+ Assert(n.getType().isBag());
+ Assert(e.getType() == n.getType().getBagElementType());
+
+ InferInfo inferInfo;
+ inferInfo.d_id = Inference::BAG_NON_NEGATIVE_COUNT;
+ Node count = d_nm->mkNode(kind::BAG_COUNT, e, n);
+
+ Node gte = d_nm->mkNode(kind::GEQ, count, d_zero);
+ inferInfo.d_conclusion = gte;
+ return inferInfo;
+}
+
InferInfo InferenceGenerator::mkBag(Node n, Node e)
{
Assert(n.getKind() == kind::MK_BAG);
Assert(e.getType() == n.getType().getBagElementType());
InferInfo inferInfo;
- inferInfo.d_id = Inference::BAG_MK_BAG;
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
if (n[0] == e)
{
- // TODO: refactor this with the rewriter
+ // TODO issue #78: refactor this with BagRewriter
// (=> true (= (bag.count e (bag e c)) c))
+ inferInfo.d_id = Inference::BAG_MK_BAG_SAME_ELEMENT;
inferInfo.d_conclusion = count.eqNode(n[1]);
}
else
// (=>
// true
// (= (bag.count e (bag x c)) (ite (= e x) c 0)))
-
+ inferInfo.d_id = Inference::BAG_MK_BAG;
Node same = d_nm->mkNode(kind::EQUAL, n[0], e);
Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero);
Node equal = count.eqNode(ite);
return inferInfo;
}
-InferInfo InferenceGenerator::bagEquality(Node n, Node e)
-{
- Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
- Assert(e.getType() == n[0].getType().getBagElementType());
-
- Node A = n[0];
- Node B = n[1];
- InferInfo inferInfo;
- inferInfo.d_id = Inference::BAG_EQUALITY;
- inferInfo.d_premises.push_back(n);
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
-
- Node equal = countA.eqNode(countB);
- inferInfo.d_conclusion = equal;
- return inferInfo;
-}
-
struct BagsDeqAttributeId
{
};
typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
-InferInfo InferenceGenerator::bagDisequality(Node n)
+InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
{
Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
Assert(n[0][0].getType().isBag());
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DISEQUALITY;
+ inferInfo.d_premises.push_back(reason);
TypeNode elementType = A.getType().getBagElementType();
-
BoundVarManager* bvm = d_nm->getBoundVarManager();
Node element = bvm->mkBoundVar<BagsDeqAttribute>(n, elementType);
- SkolemManager* sm = d_nm->getSkolemManager();
Node skolem =
- sm->mkSkolem(element,
- n,
- "bag_disequal",
- "an extensional lemma for disequality of two bags");
+ d_sm->mkSkolem(element,
+ n,
+ "bag_disequal",
+ "an extensional lemma for disequality of two bags");
- inferInfo.d_newSkolem.push_back(skolem);
-
- Node countA = getMultiplicitySkolem(skolem, A, inferInfo);
- Node countB = getMultiplicitySkolem(skolem, B, inferInfo);
+ Node countA = getMultiplicityTerm(skolem, A);
+ Node countB = getMultiplicityTerm(skolem, B);
Node disEqual = countA.eqNode(countB).notNode();
return inferInfo;
}
+Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
+{
+ Node skolem = d_sm->mkPurifySkolem(n, "skolem_bag", "skolem bag");
+ inferInfo.d_skolems[n] = skolem;
+ return skolem;
+}
+
InferInfo InferenceGenerator::bagEmpty(Node e)
{
EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
Node empty = d_nm->mkConst(emptyBag);
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_EMPTY;
- Node count = getMultiplicitySkolem(e, empty, inferInfo);
+ Node count = getMultiplicityTerm(e, empty);
Node equal = count.eqNode(d_zero);
inferInfo.d_conclusion = equal;
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_UNION_DISJOINT;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node sum = d_nm->mkNode(kind::PLUS, countA, countB);
Node equal = count.eqNode(sum);
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_UNION_MAX;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node gt = d_nm->mkNode(kind::GT, countA, countB);
Node max = d_nm->mkNode(kind::ITE, gt, countA, countB);
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_INTERSECTION_MIN;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node lt = d_nm->mkNode(kind::LT, countA, countB);
Node min = d_nm->mkNode(kind::ITE, lt, countA, countB);
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DIFFERENCE_SUBTRACT;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node subtract = d_nm->mkNode(kind::MINUS, countA, countB);
Node gte = d_nm->mkNode(kind::GEQ, countA, countB);
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DIFFERENCE_REMOVE;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero);
Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero);
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DUPLICATE_REMOVAL;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node gte = d_nm->mkNode(kind::GEQ, countA, d_one);
Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero);
return inferInfo;
}
-Node InferenceGenerator::getMultiplicitySkolem(Node element,
- Node bag,
- InferInfo& inferInfo)
+Node InferenceGenerator::getMultiplicityTerm(Node element, Node bag)
{
Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag);
- Node skolem = d_state->registerBagElement(count);
- eq::EqualityEngine* ee = d_state->getEqualityEngine();
- ee->assertEquality(skolem.eqNode(count), true, d_nm->mkConst(true));
- inferInfo.d_newSkolem.push_back(skolem);
- return skolem;
+ return count;
}
} // namespace bags
InferenceGenerator(SolverState* state);
/**
- * @param n is (bag x c) of type (Bag E)
+ * @param A is a bag of type (Bag E)
* @param e is a node of type E
* @return an inference that represents the following implication
* (=>
* true
- * (= (bag.count e (bag x c)) (ite (= e x) c 0)))
+ * (>= (bag.count e A) 0)
*/
- InferInfo mkBag(Node n, Node e);
+ InferInfo nonNegativeCount(Node n, Node e);
/**
- * @param n is (= A B) where A, B are bags of type (Bag E)
- * @param e is a node of Type E
+ * @param n is (bag x c) of type (Bag E)
+ * @param e is a node of type E
* @return an inference that represents the following implication
* (=>
- * (= A B)
- * (= (count e A) (count e B)))
+ * true
+ * (= (bag.count e skolem) c))
+ * if e is exactly node x. Node skolem is a fresh variable equals (bag x c).
+ * Otherwise the following inference is returned
+ * (=>
+ * true
+ * (= (bag.count e skolem) (ite (= e x) c 0)))
*/
- InferInfo bagEquality(Node n, Node e);
+ InferInfo mkBag(Node n, Node e);
/**
* @param n is (not (= A B)) where A, B are bags of type (Bag E)
* @return an inference that represents the following implication
* (=>
* (not (= A B))
* (not (= (count e A) (count e B))))
- * where e is a fresh skolem of type E
+ * where e is a fresh skolem of type E.
*/
- InferInfo bagDisequality(Node n);
+ InferInfo bagDisequality(Node n, Node reason);
/**
* @param e is a node of Type E
* @return an inference that represents the following implication
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e k_{(union_disjoint A B)})
+ * (= (count e skolem)
* (+ (count e A) (count e B))))
- * where k_{(union_disjoint A B)} is a unique purification skolem
- * for (union_disjoint A B)
+ * where skolem is a fresh variable equals (union_disjoint A B)
*/
InferInfo unionDisjoint(Node n, Node e);
/**
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (union_max A B))
+ * (=
+ * (count e skolem)
* (ite
- * (> (count e A) (count e B))
- * (count e A)
- * (count e B)))))
+ * (> (count e A) (count e B))
+ * (count e A)
+ * (count e B)))))
+ * where skolem is a fresh variable equals (union_max A B)
*/
InferInfo unionMax(Node n, Node e);
/**
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (intersection_min A B))
+ * (=
+ * (count e skolem)
* (ite(
- * (< (count e A) (count e B))
- * (count e A)
- * (count e B)))))
+ * (< (count e A) (count e B))
+ * (count e A)
+ * (count e B)))))
+ * where skolem is a fresh variable equals (intersection_min A B)
*/
InferInfo intersection(Node n, Node e);
/**
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (difference_subtract A B))
+ * (=
+ * (count e skolem)
* (ite
- * (>= (count e A) (count e B))
- * (- (count e A) (count e B))
- * 0))))
+ * (>= (count e A) (count e B))
+ * (- (count e A) (count e B))
+ * 0))))
+ * where skolem is a fresh variable equals (difference_subtract A B)
*/
InferInfo differenceSubtract(Node n, Node e);
/**
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (difference_remove A B))
+ * (=
+ * (count e skolem)
* (ite
- * (= (count e B) 0)
- * (count e A)
- * 0))))
+ * (= (count e B) 0)
+ * (count e A)
+ * 0))))
+ * where skolem is a fresh variable equals (difference_remove A B)
*/
InferInfo differenceRemove(Node n, Node e);
/**
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (duplicate_removal A))
- * (ite (>= (count e A) 1) 1 0))))
+ * (=
+ * (count e skolem)
+ * (ite (>= (count e A) 1) 1 0))))
+ * where skolem is a fresh variable equals (duplicate_removal A)
*/
InferInfo duplicateRemoval(Node n, Node e);
/**
* @param element of type T
* @param bag of type (bag T)
- * @param inferInfo to store new skolem
- * @return a skolem for (bag.count element bag)
+ * @return a count term (bag.count element bag)
*/
- Node getMultiplicitySkolem(Node element, Node bag, InferInfo& inferInfo);
+ Node getMultiplicityTerm(Node element, Node bag);
private:
+ /** generate skolem variable for node n and add it to inferInfo */
+ Node getSkolem(Node& n, InferInfo& inferInfo);
+
NodeManager* d_nm;
SkolemManager* d_sm;
SolverState* d_state;
* process the pending lemmas and then the pending phase requirements.
* Notice that we process the pending lemmas even if there were facts.
*/
- // TODO: refactor this before merge with theory of strings
+ // TODO issue #78: refactor this with theory of strings
void doPending();
private:
{
d_true = NodeManager::currentNM()->mkConst(true);
d_false = NodeManager::currentNM()->mkConst(false);
+ d_nm = NodeManager::currentNM();
}
-struct BagsCountAttributeId
+void SolverState::registerBag(TNode n)
{
-};
-typedef expr::Attribute<BagsCountAttributeId, Node> BagsCountAttribute;
-
-void SolverState::registerClass(TNode n)
-{
- TypeNode t = n.getType();
- if (!t.isBag())
- {
- return;
- }
+ Assert(n.getType().isBag());
d_bags.insert(n);
}
-Node SolverState::registerBagElement(TNode n)
+void SolverState::registerCountTerm(TNode n)
{
Assert(n.getKind() == BAG_COUNT);
- Node element = n[0];
- TypeNode elementType = element.getType();
- Node bag = n[1];
- d_elements[elementType].insert(element);
- NodeManager* nm = NodeManager::currentNM();
- BoundVarManager* bvm = nm->getBoundVarManager();
- Node multiplicity = bvm->mkBoundVar<BagsCountAttribute>(n, nm->integerType());
- Node equal = n.eqNode(multiplicity);
- SkolemManager* sm = nm->getSkolemManager();
- Node skolem = sm->mkSkolem(
- multiplicity,
- equal,
- "bag_multiplicity",
- "an extensional lemma for multiplicity of an element in a bag");
- d_count[bag][element] = skolem;
- Trace("bags::SolverState::registerBagElement")
- << "New skolem: " << skolem << " for " << n << std::endl;
-
- return skolem;
+ Node element = getRepresentative(n[0]);
+ Node bag = getRepresentative(n[1]);
+ d_bagElements[bag].insert(element);
+}
+
+const std::set<Node>& SolverState::getBags() { return d_bags; }
+
+const std::set<Node>& SolverState::getElements(Node B)
+{
+ Node bag = getRepresentative(B);
+ return d_bagElements[B];
+}
+
+void SolverState::reset()
+{
+ d_bagElements.clear();
+ d_bags.clear();
}
-std::set<Node>& SolverState::getBags() { return d_bags; }
+void SolverState::initialize()
+{
+ reset();
+ collectBagsAndCountTerms();
+}
+
+void SolverState::collectBagsAndCountTerms()
+{
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "SolverState::collectBagsAndCountTerms start" << endl;
+ eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
+ while (!repIt.isFinished())
+ {
+ Node eqc = (*repIt);
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "[" << eqc << "]: " << endl;
+
+ if (eqc.getType().isBag())
+ {
+ registerBag(eqc);
+ }
-std::set<Node>& SolverState::getElements(TypeNode t) { return d_elements[t]; }
+ eq::EqClassIterator it = eq::EqClassIterator(eqc, d_ee);
+ while (!it.isFinished())
+ {
+ Node n = (*it);
+ Kind k = n.getKind();
+ if (k == MK_BAG)
+ {
+ // for terms (bag x c) we need to store x by registering the count term
+ // (bag.count x (bag x c))
+ Node count = d_nm->mkNode(BAG_COUNT, n[0], n);
+ registerCountTerm(count);
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "registered " << count << endl;
+ }
+ if (k == BAG_COUNT)
+ {
+ // this takes care of all count terms in each equivalent class
+ registerCountTerm(n);
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "registered " << n << endl;
+ }
+ ++it;
+ }
-std::map<Node, Node>& SolverState::getBagElements(Node B) { return d_count[B]; }
+ ++repIt;
+ }
+
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "SolverState::collectBagsAndCountTerms end" << endl;
+}
} // namespace bags
} // namespace theory
public:
SolverState(context::Context* c, context::UserContext* u, Valuation val);
- void registerClass(TNode n);
+ /**
+ * This function adds the bag representative n to the set d_bags if it is not
+ * already there. This function is called during postCheck to collect bag
+ * terms in the equality engine. See the documentation of
+ * @link SolverState::collectBagsAndCountTerms
+ */
+ void registerBag(TNode n);
- Node registerBagElement(TNode n);
-
- std::set<Node>& getBags();
-
- std::set<Node>& getElements(TypeNode t);
-
- std::map<Node, Node>& getBagElements(Node B);
+ /**
+ * @param n has the form (bag.count e A)
+ * @pre bag A needs is already registered using registerBag(A)
+ * @return a unique skolem for (bag.count e A)
+ */
+ void registerCountTerm(TNode n);
+ /** get all bag terms that are representatives in the equality engine.
+ * This function is valid after the current solver is initialized during
+ * postCheck. See SolverState::initialize and BagSolver::postCheck
+ */
+ const std::set<Node>& getBags();
+ /**
+ * @pre B is a registered bag
+ * @return all elements associated with bag B so far
+ * Note that associated elements are not necessarily elements in B
+ * Example:
+ * (assert (= 0 (bag.count x B)))
+ * element x is associated with bag B, albeit x is definitely not in B.
+ */
+ const std::set<Node>& getElements(Node B);
+ /** initialize bag and count terms */
+ void initialize();
private:
+ /** clear all bags data structures */
+ void reset();
+ /** collect bags' representatives and all count terms.
+ * This function is called during postCheck */
+ void collectBagsAndCountTerms();
/** constants */
Node d_true;
Node d_false;
+ /** node manager for this solver state */
+ NodeManager* d_nm;
+ /** collection of bag representatives */
std::set<Node> d_bags;
- std::map<TypeNode, std::set<Node>> d_elements;
- /** bag -> element -> multiplicity */
- std::map<Node, std::map<Node, Node>> d_count;
+ /** bag -> associated elements */
+ std::map<Node, std::set<Node>> d_bagElements;
}; /* class SolverState */
} // namespace bags
void TheoryBags::postCheck(Effort effort)
{
d_im.doPendingFacts();
- // TODO: clean this before merge Assert(d_strat.isStrategyInit());
+ // TODO issue #78: add Assert(d_strat.isStrategyInit());
if (!d_state.isInConflict() && !d_valuation.needCheck())
- // TODO: clean this before merge && d_strat.hasStrategyEffort(e))
+ // TODO issue #78: add && d_strat.hasStrategyEffort(e))
{
Trace("bags::TheoryBags::postCheck") << "effort: " << std::endl;
- // TODO: clean this before merge ++(d_statistics.d_checkRuns);
+ // TODO issue #78: add ++(d_statistics.d_checkRuns);
bool sentLemma = false;
bool hadPending = false;
Trace("bags-check") << "Full effort check..." << std::endl;
do
{
d_im.reset();
- // TODO: clean this before merge ++(d_statistics.d_strategyRuns);
+ // TODO issue #78: add ++(d_statistics.d_strategyRuns);
Trace("bags-check") << " * Run strategy..." << std::endl;
- // TODO: clean this before merge runStrategy(e);
+ // TODO issue #78: add runStrategy(e);
d_solver.postCheck();
continue;
}
Node r = d_state.getRepresentative(n);
- std::map<Node, Node> elements = d_state.getBagElements(r);
+ std::set<Node> solverElements = d_state.getElements(r);
+ std::set<Node> elements;
+ // only consider terms in termSet and ignore other elements in the solver
+ std::set_intersection(termSet.begin(),
+ termSet.end(),
+ solverElements.begin(),
+ solverElements.end(),
+ std::inserter(elements, elements.begin()));
Trace("bags-model") << "Elements of bag " << n << " are: " << std::endl
<< elements << std::endl;
std::map<Node, Node> elementReps;
- for (std::pair<Node, Node> pair : elements)
+ for (const Node& e : elements)
{
- Node key = d_state.getRepresentative(pair.first);
- Node value = d_state.getRepresentative(pair.second);
+ Node key = d_state.getRepresentative(e);
+ Node countTerm = NodeManager::currentNM()->mkNode(BAG_COUNT, e, r);
+ Node value = d_state.getRepresentative(countTerm);
elementReps[key] = value;
}
Node rep = NormalForm::constructBagFromElements(tn, elementReps);
/**************************** eq::NotifyClass *****************************/
-void TheoryBags::eqNotifyNewClass(TNode n)
-{
- Kind k = n.getKind();
- d_state.registerClass(n);
- if (n.getKind() == MK_BAG)
- {
- // TODO: refactor this before merge
- /*
- * (bag x m) generates the lemma (and (= s (count x (bag x m))) (= s m))
- * where s is a fresh skolem variable
- */
- NodeManager* nm = NodeManager::currentNM();
- Node count = nm->mkNode(BAG_COUNT, n[0], n);
- Node skolem = d_state.registerBagElement(count);
- Node countSkolem = count.eqNode(skolem);
- Node skolemMultiplicity = n[1].eqNode(skolem);
- Node lemma = countSkolem.andNode(skolemMultiplicity);
- TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- d_im.trustedLemma(trustedLemma);
- }
- if (k == BAG_COUNT)
- {
- /*
- * (count x A) generates the lemma (= s (count x A))
- * where s is a fresh skolem variable
- */
- Node skolem = d_state.registerBagElement(n);
- Node lemma = n.eqNode(skolem);
- TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- d_im.trustedLemma(trustedLemma);
- }
-}
+void TheoryBags::eqNotifyNewClass(TNode n) {}
void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
TypeNode t1 = n1.getType();
if (t1.isBag())
{
- InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode());
- Node lemma = reason.impNode(info.d_conclusion);
- TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- d_im.trustedLemma(trustedLemma);
+ InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason);
+ info.process(d_inferManager, true);
}
}
regress1/bug681.smt2
regress1/bug694-Unapply1.scala-0.smt2
regress1/bug800.smt2
+ regress1/bags/difference_remove1.smt2
regress1/bags/disequality.smt2
+ regress1/bags/issue5759.smt2
regress1/bags/subbag1.smt2
regress1/bags/subbag2.smt2
regress1/bags/union_disjoint.smt2
--- /dev/null
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(assert (= A (union_max (bag x 1) (bag y 2))))
+(assert (= A (union_disjoint B (bag y 2))))
+(assert (= x y))
+(check-sat)
--- /dev/null
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(assert (not (= A (union_max (bag x 1) (bag 0 1)))))
+(assert (= A (union_disjoint B (bag 0 1))))
+(assert (= x 1))
+(check-sat)
\ No newline at end of file