From: Andres Noetzli Date: Mon, 15 Nov 2021 15:49:26 +0000 (-0800) Subject: [Strings] Minor refactor of eager solver (#7628) X-Git-Tag: cvc5-1.0.0~818 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=cf1a63a7dc613099129ff612ca11845a3eb3aa0f;p=cvc5.git [Strings] Minor refactor of eager solver (#7628) This moves code that is not strictly related to the eager solver out of the eager solver and into TheoryStrings. This is cleaner and makes it easier to enable/disable the eager solver. --- diff --git a/src/theory/strings/eager_solver.cpp b/src/theory/strings/eager_solver.cpp index 21fdd6fa2..ce47ff4fc 100644 --- a/src/theory/strings/eager_solver.cpp +++ b/src/theory/strings/eager_solver.cpp @@ -37,41 +37,30 @@ EagerSolver::~EagerSolver() {} void EagerSolver::eqNotifyNewClass(TNode t) { Kind k = t.getKind(); - if (k == STRING_LENGTH || k == STRING_TO_CODE) + if (k == STRING_LENGTH) { - eq::EqualityEngine* ee = d_state.getEqualityEngine(); - Node r = ee->getRepresentative(t[0]); - EqcInfo* ei = d_state.getOrMakeEqcInfo(r); - if (k == STRING_LENGTH) + // also assume it as upper/lower bound as applicable for the equivalence + // class info of t. + EqcInfo* eil = nullptr; + for (size_t i = 0; i < 2; i++) { - ei->d_lengthTerm = t; - // also assume it as upper/lower bound as applicable for the equivalence - // class info of t. - EqcInfo* eil = nullptr; - for (size_t i = 0; i < 2; i++) + Node b = getBoundForLength(t, i == 0); + if (b.isNull()) { - Node b = getBoundForLength(t, i == 0); - if (b.isNull()) - { - continue; - } - if (eil == nullptr) - { - eil = d_state.getOrMakeEqcInfo(t); - } - if (i == 0) - { - eil->d_firstBound = t; - } - else if (i == 1) - { - eil->d_secondBound = t; - } + continue; + } + if (eil == nullptr) + { + eil = d_state.getOrMakeEqcInfo(t); + } + if (i == 0) + { + eil->d_firstBound = t; + } + else if (i == 1) + { + eil->d_secondBound = t; } - } - else - { - ei->d_codeTerm = t[0]; } } else if (t.isConst()) @@ -90,15 +79,10 @@ void EagerSolver::eqNotifyNewClass(TNode t) } } -void EagerSolver::eqNotifyMerge(TNode t1, TNode t2) +void EagerSolver::eqNotifyMerge(EqcInfo* e1, TNode t1, EqcInfo* e2, TNode t2) { - EqcInfo* e2 = d_state.getOrMakeEqcInfo(t2, false); - if (e2 == nullptr) - { - return; - } - // always create it if e2 was non-null - EqcInfo* e1 = d_state.getOrMakeEqcInfo(t1); + Assert(e1 != nullptr); + Assert(e2 != nullptr); // check for conflict Node conf = checkForMergeConflict(t1, t2, e1, e2); if (!conf.isNull()) @@ -109,33 +93,6 @@ void EagerSolver::eqNotifyMerge(TNode t1, TNode t2) d_state.setPendingMergeConflict(conf, id); return; } - // add information from e2 to e1 - if (!e2->d_lengthTerm.get().isNull()) - { - e1->d_lengthTerm.set(e2->d_lengthTerm); - } - if (!e2->d_codeTerm.get().isNull()) - { - e1->d_codeTerm.set(e2->d_codeTerm); - } - if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get()) - { - e1->d_cardinalityLemK.set(e2->d_cardinalityLemK); - } - if (!e2->d_normalizedLength.get().isNull()) - { - e1->d_normalizedLength.set(e2->d_normalizedLength); - } -} - -void EagerSolver::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) -{ - if (t1.getType().isStringLike()) - { - // store disequalities between strings, may need to check if their lengths - // are equal/disequal - d_state.addDisequality(t1, t2); - } } void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc) diff --git a/src/theory/strings/eager_solver.h b/src/theory/strings/eager_solver.h index 03fb0ff63..4181a15c3 100644 --- a/src/theory/strings/eager_solver.h +++ b/src/theory/strings/eager_solver.h @@ -46,9 +46,7 @@ class EagerSolver : protected EnvObj /** called when a new equivalence class is created */ void eqNotifyNewClass(TNode t); /** called when two equivalence classes merge */ - void eqNotifyMerge(TNode t1, TNode t2); - /** called when two equivalence classes are made disequal */ - void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); + void eqNotifyMerge(EqcInfo* e1, TNode t1, EqcInfo* e2, TNode t2); /** notify fact, called when a fact is asserted to theory of strings */ void notifyFact(TNode atom, bool polarity, TNode fact, bool isInternal); diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index ed00758a8..caeb8065e 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -749,10 +749,63 @@ void TheoryStrings::eqNotifyNewClass(TNode t){ Trace("strings-debug") << "New length eqc : " << t << std::endl; //we care about the length of this string d_termReg.registerTerm(t[0], 1); + + eq::EqualityEngine* ee = d_state.getEqualityEngine(); + Node r = ee->getRepresentative(t[0]); + EqcInfo* ei = d_state.getOrMakeEqcInfo(r); + if (k == STRING_LENGTH) + { + ei->d_lengthTerm = t; + } + else + { + ei->d_codeTerm = t[0]; + } } d_eagerSolver.eqNotifyNewClass(t); } +void TheoryStrings::eqNotifyMerge(TNode t1, TNode t2) +{ + EqcInfo* e2 = d_state.getOrMakeEqcInfo(t2, false); + if (e2 == nullptr) + { + return; + } + // always create it if e2 was non-null + EqcInfo* e1 = d_state.getOrMakeEqcInfo(t1); + + d_eagerSolver.eqNotifyMerge(e1, t1, e2, t2); + + // add information from e2 to e1 + if (!e2->d_lengthTerm.get().isNull()) + { + e1->d_lengthTerm.set(e2->d_lengthTerm); + } + if (!e2->d_codeTerm.get().isNull()) + { + e1->d_codeTerm.set(e2->d_codeTerm); + } + if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get()) + { + e1->d_cardinalityLemK.set(e2->d_cardinalityLemK); + } + if (!e2->d_normalizedLength.get().isNull()) + { + e1->d_normalizedLength.set(e2->d_normalizedLength); + } +} + +void TheoryStrings::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) +{ + if (t1.getType().isStringLike()) + { + // store disequalities between strings, may need to check if their lengths + // are equal/disequal + d_state.addDisequality(t1, t2); + } +} + void TheoryStrings::addCarePairs(TNodeTrie* t1, TNodeTrie* t2, unsigned arity, diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index dbb04580f..21db7da0c 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -108,6 +108,10 @@ class TheoryStrings : public Theory { void conflict(TNode a, TNode b); /** called when a new equivalence class is created */ void eqNotifyNewClass(TNode t); + /** Called just after the merge of two equivalence classes */ + void eqNotifyMerge(TNode t1, TNode t2); + /** called a disequality is added */ + void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); /** preprocess rewrite */ TrustNode ppRewrite(TNode atom, std::vector& lems) override; /** Collect model values in m based on the relevant terms given by termSet */ @@ -118,9 +122,7 @@ class TheoryStrings : public Theory { /** NotifyClass for equality engine */ class NotifyClass : public eq::EqualityEngineNotify { public: - NotifyClass(TheoryStrings& ts) : d_str(ts), d_eagerSolver(ts.d_eagerSolver) - { - } + NotifyClass(TheoryStrings& ts) : d_str(ts) {} bool eqNotifyTriggerPredicate(TNode predicate, bool value) override { Debug("strings") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate @@ -156,19 +158,17 @@ class TheoryStrings : public Theory { { Debug("strings") << "NotifyClass::eqNotifyMerge(" << t1 << ", " << t2 << std::endl; - d_eagerSolver.eqNotifyMerge(t1, t2); + d_str.eqNotifyMerge(t1, t2); } void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override { Debug("strings") << "NotifyClass::eqNotifyDisequal(" << t1 << ", " << t2 << ", " << reason << std::endl; - d_eagerSolver.eqNotifyDisequal(t1, t2, reason); + d_str.eqNotifyDisequal(t1, t2, reason); } private: /** The theory of strings object to notify */ TheoryStrings& d_str; - /** The eager solver of the theory of strings */ - EagerSolver& d_eagerSolver; };/* class TheoryStrings::NotifyClass */ /** compute care graph */ void computeCareGraph() override;