d_undoStackIndex = d_undoStackIndex + 1;
}
+void InequalityGraph::initializeModelValue(TNode node) {
+ TermId id = getTermId(node);
+ Assert (!hasModelValue(id));
+ bool isConst = node.getKind() == kind::CONST_BITVECTOR;
+ unsigned size = utils::getSize(node);
+ BitVector value = isConst? node.getConst<BitVector>() : BitVector(size, 0u);
+ setModelValue(id, ModelValue(value, UndefinedTermId, UndefinedReasonId));
+}
+
+bool InequalityGraph::isRegistered(TNode term) const {
+ return d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end();
+}
+
TermId InequalityGraph::registerTerm(TNode term) {
if (d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end()) {
- return d_termNodeToIdMap[term];
+ TermId id = d_termNodeToIdMap[term];
+ if (!hasModelValue(id)) {
+ // we could have backtracked and
+ initializeModelValue(term);
+ }
+ return id;
}
// store in node mapping
// create InequalityNode
unsigned size = utils::getSize(term);
+
bool isConst = term.getKind() == kind::CONST_BITVECTOR;
- BitVector value = isConst? term.getConst<BitVector>() : BitVector(size, 0u);
-
InequalityNode ineq = InequalityNode(id, size, isConst);
- setModelValue(id, ModelValue(value, UndefinedTermId, UndefinedReasonId));
-
+
Assert (d_ineqNodes.size() == id);
d_ineqNodes.push_back(ineq);
Assert (d_ineqEdges.size() == id);
d_ineqEdges.push_back(Edges());
- // add the default edges min <= term <= max
- // addEdge(getMinValueId(size), id, false, AxiomReasonId);
- // addEdge(id, getMaxValueId(size), false, AxiomReasonId);
+ initializeModelValue(term);
return id;
}
return d_termNodes[id];
}
+TermId InequalityGraph::getTermId(TNode node) const {
+ Assert (d_termNodeToIdMap.find(node) != d_termNodeToIdMap.end());
+ return d_termNodeToIdMap.find(node)->second;
+}
+
void InequalityGraph::setConflict(const std::vector<ReasonId>& conflict) {
Assert (!d_inConflict);
d_inConflict = true;
}
BitVector InequalityGraph::getValue(TermId id) const {
- Assert (hasModelValue(id));
- BitVector res = (*(d_modelValues.find(id))).second.value;
+ Assert (hasModelValue(id));
+ BitVector res = (*(d_modelValues.find(id))).second.value;
return res;
}
return mv.reason != UndefinedReasonId;
}
+bool InequalityGraph::addDisequality(TNode a, TNode b, TNode reason) {
+ Debug("bv-inequality") << "InequalityGraph::addDisequality " << reason << "\n";
+ d_disequalities.push_back(reason);
+
+ if (!isRegistered(a) || !isRegistered(b)) {
+ splitDisequality(reason);
+ return true;
+ }
+ TermId id_a = getTermId(a);
+ TermId id_b = getTermId(b);
+ if (!hasModelValue(id_a)) {
+ initializeModelValue(a);
+ }
+ if (!hasModelValue(id_b)) {
+ initializeModelValue(b);
+ }
+ const BitVector& val_a = getValue(id_a);
+ const BitVector& val_b = getValue(id_b);
+ if (val_a == val_b) {
+ if (a.getKind() == kind::CONST_BITVECTOR) {
+ // then we know b cannot be smaller than the assigned value so we try to make it larger
+ return addInequality(a, b, true, reason);
+ }
+ if (b.getKind() == kind::CONST_BITVECTOR) {
+ return addInequality(b, a, true, reason);
+ }
+ // if none of the terms are constants just add the lemma
+ splitDisequality(reason);
+ } else {
+ Debug("bv-inequality-internal") << "Disequal: " << a << " => " << val_a.toString(10) << "\n"
+ << " " << b << " => " << val_b.toString(10) << "\n";
+ }
+ return true;
+}
+
+void InequalityGraph::splitDisequality(TNode diseq) {
+ Debug("bv-inequality-internal")<<"InequalityGraph::splitDisequality " << diseq <<"\n";
+ Assert (diseq.getKind() == kind::NOT && diseq[0].getKind() == kind::EQUAL);
+ TNode a = diseq[0][0];
+ TNode b = diseq[0][1];
+ Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b);
+ Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a);
+ Node split = utils::mkNode(kind::OR, a_lt_b, b_lt_a);
+ Node lemma = utils::mkNode(kind::IMPLIES, diseq, split);
+ if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) {
+ d_lemmaQueue.push_back(lemma);
+ }
+}
+
+void InequalityGraph::getNewLemmas(std::vector<TNode>& new_lemmas) {
+ for (unsigned i = d_lemmaIndex; i < d_lemmaQueue.size(); ++i) {
+ TNode lemma = d_lemmaQueue[i];
+ if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) {
+ new_lemmas.push_back(lemma);
+ d_lemmasAdded.insert(lemma);
+ }
+ d_lemmaIndex = d_lemmaIndex + 1;
+ }
+}
+
std::string InequalityGraph::PQueueElement::toString() const {
ostringstream os;
os << "(id: " << id << ", lower_bound: " << lower_bound.toString(10) <<", old_value: " << model_value.value.toString(10) << ")";
typedef __gnu_cxx::hash_set<TermId> TermIdSet;
typedef std::priority_queue<PQueueElement> BFSQueue;
-
+ typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet;
std::vector<InequalityNode> d_ineqNodes;
std::vector< Edges > d_ineqEdges;
std::vector<TNode> d_conflict;
bool d_signed;
- context::CDHashMap<TermId, ModelValue> d_modelValues;
+ context::CDHashMap<TermId, ModelValue> d_modelValues;
+ void initializeModelValue(TNode node);
void setModelValue(TermId term, const ModelValue& mv);
ModelValue getModelValue(TermId term) const;
bool hasModelValue(TermId id) const;
TermId registerTerm(TNode term);
TNode getTermNode(TermId id) const;
TermId getTermId(TNode node) const;
-
+ bool isRegistered(TNode term) const;
+
ReasonId registerReason(TNode reason);
TNode getReasonNode(ReasonId id) const;
const InequalityNode& getInequalityNode(TermId id) const { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; }
unsigned getBitwidth(TermId id) const { return getInequalityNode(id).getBitwidth(); }
bool isConst(TermId id) const { return getInequalityNode(id).isConstant(); }
- // BitVector maxValue(unsigned bitwidth);
- // BitVector minValue(unsigned bitwidth);
- // TermId getMaxValueId(unsigned bitwidth);
- // TermId getMinValueId(unsigned bitwidth);
BitVector getValue(TermId id) const;
* @param explanation
*/
void computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation);
+ void splitDisequality(TNode diseq);
+ /**
+ Disequality reasoning
+ */
+
+ /*** The currently asserted disequalities */
+ context::CDQueue<TNode> d_disequalities;
+ context::CDQueue<Node> d_lemmaQueue;
+ context::CDO<unsigned> d_lemmaIndex;
+ TNodeSet d_lemmasAdded;
+
/** Backtracking mechanisms **/
std::vector<std::pair<TermId, InequalityEdge> > d_undoStack;
context::CDO<unsigned> d_undoStackIndex;
d_conflict(),
d_signed(s),
d_modelValues(c),
+ d_disequalities(c),
+ d_lemmaQueue(c),
+ d_lemmaIndex(c, 0),
+ d_lemmasAdded(),
d_undoStack(),
d_undoStackIndex(c)
{}
* @return
*/
bool addInequality(TNode a, TNode b, bool strict, TNode reason);
+ bool addDisequality(TNode a, TNode b, TNode reason);
bool areLessThan(TNode a, TNode b);
void getConflict(std::vector<TNode>& conflict);
virtual ~InequalityGraph() throw(AssertionException) {}
+ void getNewLemmas(std::vector<TNode>& new_lemmas);
};
}