From: lianah Date: Mon, 25 Mar 2013 22:24:29 +0000 (-0400) Subject: getEqualityStatus now also queries the inequality solver X-Git-Tag: cvc5-1.0.0~7361^2~17 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=7f9b419adf3e45ce12ab9fb9b2d1afa076110e7d;p=cvc5.git getEqualityStatus now also queries the inequality solver --- diff --git a/src/theory/bv/bv_inequality_graph.cpp b/src/theory/bv/bv_inequality_graph.cpp index 4bd315872..a1d2efbb5 100644 --- a/src/theory/bv/bv_inequality_graph.cpp +++ b/src/theory/bv/bv_inequality_graph.cpp @@ -29,55 +29,6 @@ const ReasonId CVC4::theory::bv::UndefinedReasonId = -1; const ReasonId CVC4::theory::bv::AxiomReasonId = -2; -// BitVector InequalityGraph::maxValue(unsigned bitwidth) { -// if (d_signed) { -// return BitVector(1, 0u).concat(~BitVector(bitwidth - 1, 0u)); -// } -// return ~BitVector(bitwidth, 0u); -// } - -// BitVector InequalityGraph::minValue(unsigned bitwidth) { -// if (d_signed) { -// return ~BitVector(bitwidth, 0u); -// } -// return BitVector(bitwidth, 0u); -// } - -// TermId InequalityGraph::getMaxValueId(unsigned bitwidth) { -// BitVector bv = maxValue(bitwidth); -// Node max = utils::mkConst(bv); - -// if (d_termNodeToIdMap.find(max) == d_termNodeToIdMap.end()) { -// TermId id = d_termNodes.size(); -// d_termNodes.push_back(max); -// d_termNodeToIdMap[max] = id; -// InequalityNode node(id, bitwidth, true, bv); -// d_ineqNodes.push_back(node); - -// // although it will never have out edges we need this to keep the size of -// // d_termNodes and d_ineqEdges in sync -// d_ineqEdges.push_back(Edges()); -// return id; -// } -// return d_termNodeToIdMap[max]; -// } - -// TermId InequalityGraph::getMinValueId(unsigned bitwidth) { -// BitVector bv = minValue(bitwidth); -// Node min = utils::mkConst(bv); - -// if (d_termNodeToIdMap.find(min) == d_termNodeToIdMap.end()) { -// TermId id = d_termNodes.size(); -// d_termNodes.push_back(min); -// d_termNodeToIdMap[min] = id; -// d_ineqEdges.push_back(Edges()); -// InequalityNode node = InequalityNode(id, bitwidth, true, bv); -// d_ineqNodes.push_back(node); -// return id; -// } -// return d_termNodeToIdMap[min]; -// } - bool InequalityGraph::addInequality(TNode a, TNode b, bool strict, TNode reason) { Debug("bv-inequality") << "InequlityGraph::addInequality " << a << " " << b << " strict: " << strict << "\n"; @@ -121,24 +72,21 @@ bool InequalityGraph::addInequality(TNode a, TNode b, bool strict, TNode reason) // add the inequality edge addEdge(id_a, id_b, strict, id_reason); - BFSQueue queue; - ModelValue mv = hasModelValue(id_a) ? getModelValue(id_a) : ModelValue(); - queue.push(PQueueElement(id_a, getValue(id_a), mv)); - TermIdSet seen; - return computeValuesBFS(queue, id_a, seen); + BFSQueue queue(&d_modelValues); + Assert (hasModelValue(id_a)); + queue.push(id_a); + return processQueue(queue, id_a); } -bool InequalityGraph::updateValue(const PQueueElement& el, TermId start, const TermIdSet& seen, bool& changed) { - TermId id = el.id; - const BitVector& lower_bound = el.lower_bound; - InequalityNode& ineqNode = getInequalityNode(id); - - if (ineqNode.isConstant()) { +bool InequalityGraph::updateValue(TermId id, ModelValue new_mv, TermId start, bool& changed) { + BitVector lower_bound = new_mv.value; + + if (isConst(id)) { if (getValue(id) < lower_bound) { Debug("bv-inequality") << "Conflict: constant " << getValue(id) << "\n"; std::vector conflict; - TermId parent = el.model_value.parent; - ReasonId reason = el.model_value.reason; + TermId parent = new_mv.parent; + ReasonId reason = new_mv.reason; conflict.push_back(reason); computeExplanation(UndefinedTermId, parent, conflict); Debug("bv-inequality") << "InequalityGraph::addInequality conflict: constant\n"; @@ -146,12 +94,12 @@ bool InequalityGraph::updateValue(const PQueueElement& el, TermId start, const T return false; } } else { - // if not constant we can update the value + // if not constant we can try to update the value if (getValue(id) < lower_bound) { // if we are updating the term we started with we must be in a cycle - if (seen.count(id) && id == start) { - TermId parent = el.model_value.parent; - ReasonId reason = el.model_value.reason; + if (id == start) { + TermId parent = new_mv.parent; + ReasonId reason = new_mv.reason; std::vector conflict; conflict.push_back(reason); computeExplanation(id, parent, conflict); @@ -163,68 +111,66 @@ bool InequalityGraph::updateValue(const PQueueElement& el, TermId start, const T << " from " << getValue(id) << "\n" << " to " << lower_bound << "\n"; changed = true; - ModelValue mv = el.model_value; - mv.value = lower_bound; - setModelValue(id, mv); + setModelValue(id, new_mv); } } return true; } -bool InequalityGraph::computeValuesBFS(BFSQueue& queue, TermId start, TermIdSet& seen) { - if (queue.empty()) - return true; - - const PQueueElement current = queue.top(); - queue.pop(); - Debug("bv-inequality-internal") << "InequalityGraph::computeValuesBFS proceessing " << getTermNode(current.id) << " " << current.toString() << "\n"; - bool updated_current = false; - if (!updateValue(current, start, seen, updated_current)) { - return false; - } - if (seen.count(current.id) && current.id == start) { - // we know what we didn't update start or we would have had a conflict - Debug("bv-inequality-internal") << "InequalityGraph::computeValuesBFS equal cycle."; - // this means we are in a cycle where all the values are forced to be equal - // TODO: make sure we collapse this cycle into one big node. - return computeValuesBFS(queue, start, seen); - } - - if (!updated_current && !(seen.count(current.id) == 0 && current.id == start)) { - // if we didn't update current we don't need to readd to the queue it's children - seen.insert(current.id); - Debug("bv-inequality-internal") << " unchanged " << getTermNode(current.id) << "\n"; - return computeValuesBFS(queue, start, seen); - } +bool InequalityGraph::processQueue(BFSQueue& queue, TermId start) { + while (!queue.empty()) { + TermId current = queue.top(); + queue.pop(); + Debug("bv-inequality-internal") << "InequalityGraph::processQueue proceessing " << getTermNode(current) << "\n"; - seen.insert(current.id); + BitVector current_value = getValue(current); - const BitVector& current_value = getValue(current.id); + unsigned size = getBitwidth(current); + const BitVector zero(size, 0u); + const BitVector one(size, 1u); - unsigned size = getBitwidth(current.id); - const BitVector zero(size, 0u); - const BitVector one(size, 1u); - - const Edges& edges = getEdges(current.id); - for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) { - TermId next = it->next; - const BitVector increment = it->strict ? one : zero; - const BitVector& next_lower_bound = current_value + increment; - if (next_lower_bound < current_value) { - // it means we have an overflow and hence a conflict + const Edges& edges = getEdges(current); + for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) { + TermId next = it->next; + ReasonId reason = it->reason; + + const BitVector increment = it->strict ? one : zero; + const BitVector next_lower_bound = current_value + increment; + + if (next_lower_bound < current_value) { + // it means we have an overflow and hence a conflict std::vector conflict; conflict.push_back(it->reason); - computeExplanation(start, current.id, conflict); + computeExplanation(start, current, conflict); Debug("bv-inequality") << "InequalityGraph::addInequality conflict: cycle \n"; setConflict(conflict); return false; + } + + ModelValue new_mv(next_lower_bound, current, reason); + bool updated = false; + if (!updateValue(next, new_mv, start, updated)) { + return false; + } + + if (next == start) { + // we know what we didn't update start or we would have had a conflict + // this means we are in a cycle where all the values are forced to be equal + Debug("bv-inequality-internal") << "InequalityGraph::processQueue equal cycle."; + continue; + } + + if (!updated) { + // if we didn't update current we don't need to add to the queue it's children + Debug("bv-inequality-internal") << " unchanged " << getTermNode(next) << "\n"; + continue; + } + + queue.push(next); + Debug("bv-inequality-internal") << " enqueue " << getTermNode(next) << "\n"; } - const BitVector& value = getValue(next); - PQueueElement el = PQueueElement(next, next_lower_bound, ModelValue(value, current.id, it->reason)); - queue.push(el); - Debug("bv-inequality-internal") << " enqueue " << getTermNode(el.id) << " " << el.toString() << "\n"; } - return computeValuesBFS(queue, start, seen); + return true; } void InequalityGraph::computeExplanation(TermId from, TermId to, std::vector& explanation) { @@ -371,8 +317,7 @@ bool InequalityGraph::hasModelValue(TermId id) const { BitVector InequalityGraph::getValue(TermId id) const { Assert (hasModelValue(id)); - BitVector res = (*(d_modelValues.find(id))).second.value; - return res; + return (*(d_modelValues.find(id))).second.value; } bool InequalityGraph::hasReason(TermId id) const { @@ -396,12 +341,21 @@ bool InequalityGraph::addDisequality(TNode a, TNode b, TNode reason) { if (!hasModelValue(id_b)) { initializeModelValue(b); } - const BitVector& val_a = getValue(id_a); - const BitVector& val_b = getValue(id_b); + const BitVector val_a = getValue(id_a); + const BitVector val_b = getValue(id_b); if (val_a == val_b) { if (a.getKind() == kind::CONST_BITVECTOR) { // then we know b cannot be smaller than the assigned value so we try to make it larger - return addInequality(a, b, true, reason); + std::vector explanation_ids; + computeExplanation(UndefinedTermId, id_b, explanation_ids); + std::vector explanation_nodes; + explanation_nodes.push_back(reason); + for (unsigned i = 0; i < explanation_ids.size(); ++i) { + explanation_nodes.push_back(getReasonNode(explanation_ids[i])); + } + Node explanation = utils::mkAnd(explanation_nodes); + d_reasonSet.insert(explanation); + return addInequality(a, b, true, explanation); } if (b.getKind() == kind::CONST_BITVECTOR) { return addInequality(b, a, true, reason); @@ -418,32 +372,26 @@ bool InequalityGraph::addDisequality(TNode a, TNode b, TNode reason) { void InequalityGraph::splitDisequality(TNode diseq) { Debug("bv-inequality-internal")<<"InequalityGraph::splitDisequality " << diseq <<"\n"; Assert (diseq.getKind() == kind::NOT && diseq[0].getKind() == kind::EQUAL); - TNode a = diseq[0][0]; - TNode b = diseq[0][1]; - Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b); - Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a); - Node split = utils::mkNode(kind::OR, a_lt_b, b_lt_a); - Node lemma = utils::mkNode(kind::IMPLIES, diseq, split); - if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) { - d_lemmaQueue.push_back(lemma); + if (d_disequalitiesAlreadySplit.find(diseq) == d_disequalitiesAlreadySplit.end()) { + d_disequalitiesToSplit.push_back(diseq); } } -void InequalityGraph::getNewLemmas(std::vector& new_lemmas) { - for (unsigned i = d_lemmaIndex; i < d_lemmaQueue.size(); ++i) { - TNode lemma = d_lemmaQueue[i]; - if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) { - new_lemmas.push_back(lemma); - d_lemmasAdded.insert(lemma); +void InequalityGraph::getNewLemmas(std::vector& new_lemmas) { + for (unsigned i = d_diseqToSplitIndex; i < d_disequalitiesToSplit.size(); ++i) { + TNode diseq = d_disequalitiesToSplit[i]; + if (d_disequalitiesAlreadySplit.find(diseq) == d_disequalitiesAlreadySplit.end()) { + TNode a = diseq[0][0]; + TNode b = diseq[0][1]; + Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b); + Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a); + Node eq = diseq[0]; + Node lemma = utils::mkNode(kind::OR, a_lt_b, b_lt_a, eq); + new_lemmas.push_back(lemma); + d_disequalitiesAlreadySplit.insert(diseq); } - d_lemmaIndex = d_lemmaIndex + 1; - } -} - -std::string InequalityGraph::PQueueElement::toString() const { - ostringstream os; - os << "(id: " << id << ", lower_bound: " << lower_bound.toString(10) <<", old_value: " << model_value.value.toString(10) << ")"; - return os.str(); + d_diseqToSplitIndex = d_diseqToSplitIndex + 1; + } } void InequalityGraph::backtrack() { @@ -467,3 +415,37 @@ void InequalityGraph::backtrack() { edges.pop_back(); } } + +void InequalityGraph::checkDisequalities() { + for (CDQueue::const_iterator it = d_disequalities.begin(); it != d_disequalities.end(); ++it) { + if (d_disequalitiesAlreadySplit.find(*it) == d_disequalitiesAlreadySplit.end()) { + // if we haven't already split on this disequality + TNode diseq = *it; + TermId a_id = registerTerm(diseq[0][0]); + TermId b_id = registerTerm(diseq[0][1]); + if (getValue(a_id) == getValue(b_id)) { + // if the disequality is not satisified by the model + d_disequalitiesToSplit.push_back(diseq); + } + } + } +} + +bool InequalityGraph::isLessThan(TNode a, TNode b) { + Assert (isRegistered(a) && isRegistered(b)); + Unimplemented(); +} + +bool InequalityGraph::hasValueInModel(TNode node) const { + if (isRegistered(node)) { + TermId id = getTermId(node); + return hasModelValue(id); + } + return false; +} + +BitVector InequalityGraph::getValueInModel(TNode node) const { + TermId id = getTermId(node); + Assert (hasModelValue(id)); + return getValue(id); +} diff --git a/src/theory/bv/bv_inequality_graph.h b/src/theory/bv/bv_inequality_graph.h index 1335eff93..b23ea7704 100644 --- a/src/theory/bv/bv_inequality_graph.h +++ b/src/theory/bv/bv_inequality_graph.h @@ -87,34 +87,37 @@ class InequalityGraph : public context::ContextNotifyObj{ value(val) {} }; + + typedef context::CDHashMap Model; - struct PQueueElement { - TermId id; - BitVector lower_bound; - ModelValue model_value; - PQueueElement(TermId id, const BitVector& lb, const ModelValue& mv) - : id(id), - lower_bound(lb), - model_value(mv) + struct QueueComparator { + const Model* d_model; + QueueComparator(const Model* model) + : d_model(model) {} - - bool operator< (const PQueueElement& other) const { - return model_value.value > other.model_value.value; + bool operator() (TermId left, TermId right) const { + Assert (d_model->find(left) != d_model->end() && + d_model->find(right) != d_model->end()); + + return (*(d_model->find(left))).second.value < (*(d_model->find(right))).second.value; } - std::string toString() const; - }; - + }; + typedef __gnu_cxx::hash_map ReasonToIdMap; typedef __gnu_cxx::hash_map TermNodeToIdMap; typedef std::vector Edges; typedef __gnu_cxx::hash_set TermIdSet; - typedef std::priority_queue BFSQueue; - typedef __gnu_cxx::hash_set TNodeSet; + typedef std::priority_queue, QueueComparator> BFSQueue; + typedef __gnu_cxx::hash_set TNodeSet; + typedef __gnu_cxx::hash_set NodeSet; + std::vector d_ineqNodes; std::vector< Edges > d_ineqEdges; - + + // to keep the explanation nodes alive + NodeSet d_reasonSet; std::vector d_reasonNodes; ReasonToIdMap d_reasonToIdMap; @@ -125,7 +128,7 @@ class InequalityGraph : public context::ContextNotifyObj{ std::vector d_conflict; bool d_signed; - context::CDHashMap d_modelValues; + Model d_modelValues; void initializeModelValue(TNode node); void setModelValue(TermId term, const ModelValue& mv); ModelValue getModelValue(TermId term) const; @@ -163,23 +166,21 @@ class InequalityGraph : public context::ContextNotifyObj{ /** * If necessary update the value in the model of the current queue element. * - * @param el current queue element we are updating + * @param id current queue element we are updating * @param start node we started with, to detect cycles - * @param seen * * @return */ - bool updateValue(const PQueueElement& el, TermId start, const TermIdSet& seen, bool& changed); + bool updateValue(TermId id, ModelValue new_mv, TermId start, bool& changed); /** * Update the current model starting with the start term. * * @param queue * @param start - * @param seen * * @return */ - bool computeValuesBFS(BFSQueue& queue, TermId start, TermIdSet& seen); + bool processQueue(BFSQueue& queue, TermId start); /** * Return the reasons why from <= to. If from is undefined we just * explain the current value of to. @@ -197,9 +198,10 @@ class InequalityGraph : public context::ContextNotifyObj{ /*** The currently asserted disequalities */ context::CDQueue d_disequalities; - context::CDQueue d_lemmaQueue; - context::CDO d_lemmaIndex; - TNodeSet d_lemmasAdded; + context::CDQueue d_disequalitiesToSplit; + context::CDO d_diseqToSplitIndex; + TNodeSet d_lemmasAdded; + TNodeSet d_disequalitiesAlreadySplit; /** Backtracking mechanisms **/ std::vector > d_undoStack; @@ -223,28 +225,72 @@ public: d_signed(s), d_modelValues(c), d_disequalities(c), - d_lemmaQueue(c), - d_lemmaIndex(c, 0), - d_lemmasAdded(), + d_disequalitiesToSplit(c), + d_diseqToSplitIndex(c, 0), + d_disequalitiesAlreadySplit(), d_undoStack(), d_undoStackIndex(c) {} /** - * + * Add a new inequality to the graph * * @param a * @param b - * @param diff + * @param strict * @param reason * * @return */ bool addInequality(TNode a, TNode b, bool strict, TNode reason); + /** + * Add a new disequality to the graph. This may lead in a lemma. + * + * @param a + * @param b + * @param reason + * + * @return + */ bool addDisequality(TNode a, TNode b, TNode reason); - bool areLessThan(TNode a, TNode b); void getConflict(std::vector& conflict); virtual ~InequalityGraph() throw(AssertionException) {} - void getNewLemmas(std::vector& new_lemmas); + /** + * Get any new lemmas (resulting from disequalities splits) that need + * to be added. + * + * @param new_lemmas + */ + void getNewLemmas(std::vector& new_lemmas); + /** + * Check that the currently asserted disequalities that have not been split on + * are still true in the current model. + */ + void checkDisequalities(); + /** + * Return true if a < b is entailed by the current set of assertions. + * + * @param a + * @param b + * + * @return + */ + bool isLessThan(TNode a, TNode b); + /** + * Returns true if the term has a value in the model (i.e. if we have seen it) + * + * @param a + * + * @return + */ + bool hasValueInModel(TNode a) const; + /** + * Return the value of a in the current model. + * + * @param a + * + * @return + */ + BitVector getValueInModel(TNode a) const; }; } diff --git a/src/theory/bv/bv_subtheory.h b/src/theory/bv/bv_subtheory.h index c442fa6dd..00b3526c0 100644 --- a/src/theory/bv/bv_subtheory.h +++ b/src/theory/bv/bv_subtheory.h @@ -91,6 +91,9 @@ public: virtual void preRegister(TNode node) {} virtual void propagate(Theory::Effort e) {} virtual void collectModelInfo(TheoryModel* m) = 0; + virtual bool isComplete() = 0; + virtual EqualityStatus getEqualityStatus(TNode a, TNode b) = 0; + bool done() { return d_assertionQueue.size() == d_assertionIndex; } TNode get() { Assert (!done()); @@ -98,8 +101,7 @@ public: d_assertionIndex = d_assertionIndex + 1; return res; } - void assertFact(TNode fact) { d_assertionQueue.push_back(fact); } - + virtual void assertFact(TNode fact) { d_assertionQueue.push_back(fact); } }; } diff --git a/src/theory/bv/bv_subtheory_bitblast.cpp b/src/theory/bv/bv_subtheory_bitblast.cpp index 2f76e32d3..20da2511c 100644 --- a/src/theory/bv/bv_subtheory_bitblast.cpp +++ b/src/theory/bv/bv_subtheory_bitblast.cpp @@ -74,7 +74,7 @@ bool BitblastSolver::check(Theory::Effort e) { d_bitblastQueue.pop(); } - // Processinga ssertions + // Processing assertions while (!done()) { TNode fact = get(); if (!d_bv->inConflict() && !d_bv->propagatedBy(fact, SUB_BITBLAST)) { diff --git a/src/theory/bv/bv_subtheory_bitblast.h b/src/theory/bv/bv_subtheory_bitblast.h index 318fdd230..47bed07dd 100644 --- a/src/theory/bv/bv_subtheory_bitblast.h +++ b/src/theory/bv/bv_subtheory_bitblast.h @@ -46,6 +46,7 @@ public: void explain(TNode literal, std::vector& assumptions); EqualityStatus getEqualityStatus(TNode a, TNode b); void collectModelInfo(TheoryModel* m); + bool isComplete() { return true; } }; } diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index 868f3754f..5eb37b50a 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -71,7 +71,7 @@ class CoreSolver : public SubtheorySolver { public: CoreSolver(context::Context* c, TheoryBV* bv); ~CoreSolver(); - bool isCoreTheory() { return d_isCoreTheory; } + bool isComplete() { return d_isCoreTheory; } void setMasterEqualityEngine(eq::EqualityEngine* eq); void preRegister(TNode node); bool check(Theory::Effort e); diff --git a/src/theory/bv/bv_subtheory_inequality.cpp b/src/theory/bv/bv_subtheory_inequality.cpp index 6b9842e8f..338026681 100644 --- a/src/theory/bv/bv_subtheory_inequality.cpp +++ b/src/theory/bv/bv_subtheory_inequality.cpp @@ -61,14 +61,20 @@ bool InequalitySolver::check(Theory::Effort e) { ok = d_inequalityGraph.addInequality(a, b, false, fact); } } + if (!ok) { std::vector conflict; d_inequalityGraph.getConflict(conflict); - d_bv->setConflict(utils::mkConjunction(conflict)); + d_bv->setConflict(utils::flattenAnd(conflict)); return false; } + + // make sure all the disequalities we didn't split on are still satisifed + // and split on the ones that are not + d_inequalityGraph.checkDisequalities(); + // send out any lemmas - std::vector lemmas; + std::vector lemmas; d_inequalityGraph.getNewLemmas(lemmas); for(unsigned i = 0; i < lemmas.size(); ++i) { d_bv->lemma(lemmas[i]); @@ -76,6 +82,38 @@ bool InequalitySolver::check(Theory::Effort e) { return true; } +EqualityStatus InequalitySolver::getEqualityStatus(TNode a, TNode b) { + Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b); + Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a); + + // if an inequality containing the terms has been asserted then we know + // the equality is false + if (d_assertionSet.contains(a_lt_b) || d_assertionSet.contains(b_lt_a)) { + return EQUALITY_FALSE; + } + + if (!d_inequalityGraph.hasValueInModel(a) || + !d_inequalityGraph.hasValueInModel(b)) { + return EQUALITY_UNKNOWN; + } + + // TODO: check if this disequality is entailed by inequalities via transitivity + + BitVector a_val = d_inequalityGraph.getValueInModel(a); + BitVector b_val = d_inequalityGraph.getValueInModel(b); + + if (a_val == b_val) { + return EQUALITY_TRUE_IN_MODEL; + } else { + return EQUALITY_FALSE_IN_MODEL; + } +} + +void InequalitySolver::assertFact(TNode fact) { + d_assertionQueue.push_back(fact); + d_assertionSet.insert(fact); +} + void InequalitySolver::explain(TNode literal, std::vector& assumptions) { Assert (false); } diff --git a/src/theory/bv/bv_subtheory_inequality.h b/src/theory/bv/bv_subtheory_inequality.h index 07c561c84..6d1d77c7e 100644 --- a/src/theory/bv/bv_subtheory_inequality.h +++ b/src/theory/bv/bv_subtheory_inequality.h @@ -21,25 +21,30 @@ #include "theory/bv/bv_subtheory.h" #include "theory/bv/bv_inequality_graph.h" +#include "context/cdhashset.h" namespace CVC4 { namespace theory { namespace bv { class InequalitySolver: public SubtheorySolver { + context::CDHashSet d_assertionSet; InequalityGraph d_inequalityGraph; public: InequalitySolver(context::Context* c, TheoryBV* bv) : SubtheorySolver(c, bv), + d_assertionSet(c), d_inequalityGraph(c) {} bool check(Theory::Effort e); void propagate(Theory::Effort e); void explain(TNode literal, std::vector& assumptions); - bool isInequalityTheory() { return true; } - virtual void collectModelInfo(TheoryModel* m) {} + bool isComplete() { return true; } + void collectModelInfo(TheoryModel* m) {} + EqualityStatus getEqualityStatus(TNode a, TNode b); + void assertFact(TNode fact); }; } diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index bc8e39e67..bdf93eadc 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -32,9 +32,6 @@ using namespace CVC4::context; using namespace std; using namespace CVC4::theory::bv::utils; - - - TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo, QuantifiersEngine* qe) : Theory(THEORY_BV, c, u, out, valuation, logicInfo, qe), d_context(c), @@ -122,11 +119,11 @@ void TheoryBV::check(Effort e) } Assert (!ok == inConflict()); - if (!inConflict() && !d_coreSolver.isCoreTheory()) { + if (!inConflict() && !d_coreSolver.isComplete()) { ok = d_inequalitySolver.check(e); } - Assert (!ok == inConflict()); + // Assert (!ok == inConflict()); // if (!inConflict() && !d_coreSolver.isCoreTheory()) { // if (!inConflict() && !d_inequalitySolver.isInequalityTheory()) { // ok = d_bitblastSolver.check(e); @@ -303,6 +300,9 @@ EqualityStatus TheoryBV::getEqualityStatus(TNode a, TNode b) } EqualityStatus status = d_coreSolver.getEqualityStatus(a, b); + if (status == EQUALITY_UNKNOWN) { + status = d_inequalitySolver.getEqualityStatus(a, b); + } if (status == EQUALITY_UNKNOWN) { status = d_bitblastSolver.getEqualityStatus(a, b); }