From f118ce702d191a5250569611da530dbf67337fb1 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 30 Jan 2020 11:53:54 -0600 Subject: [PATCH] Move disequality list to solver state in strings (#3678) --- src/theory/strings/solver_state.cpp | 52 +++++++++++++++++++++++++++ src/theory/strings/solver_state.h | 18 ++++++++++ src/theory/strings/theory_strings.cpp | 46 ++++-------------------- src/theory/strings/theory_strings.h | 2 -- 4 files changed, 77 insertions(+), 41 deletions(-) diff --git a/src/theory/strings/solver_state.cpp b/src/theory/strings/solver_state.cpp index 66ae8d6bc..664b56b17 100644 --- a/src/theory/strings/solver_state.cpp +++ b/src/theory/strings/solver_state.cpp @@ -143,6 +143,7 @@ SolverState::SolverState(context::Context* c, Valuation& v) : d_context(c), d_ee(ee), + d_eeDisequalities(c), d_valuation(v), d_conflict(c, false), d_pendingConflict(c) @@ -200,6 +201,57 @@ bool SolverState::areDisequal(Node a, Node b) const eq::EqualityEngine* SolverState::getEqualityEngine() const { return &d_ee; } +const context::CDList& SolverState::getDisequalityList() const +{ + return d_eeDisequalities; +} + +void SolverState::eqNotifyPreMerge(TNode t1, TNode t2) +{ + EqcInfo* e2 = getOrMakeEqcInfo(t2, false); + if (e2) + { + EqcInfo* e1 = getOrMakeEqcInfo(t1); + // add information from e2 to e1 + if (!e2->d_lengthTerm.get().isNull()) + { + e1->d_lengthTerm.set(e2->d_lengthTerm); + } + if (!e2->d_codeTerm.get().isNull()) + { + e1->d_codeTerm.set(e2->d_codeTerm); + } + if (!e2->d_prefixC.get().isNull()) + { + setPendingConflictWhen( + e1->addEndpointConst(e2->d_prefixC, Node::null(), false)); + } + if (!e2->d_suffixC.get().isNull()) + { + setPendingConflictWhen( + e1->addEndpointConst(e2->d_suffixC, Node::null(), true)); + } + if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get()) + { + e1->d_cardinalityLemK.set(e2->d_cardinalityLemK); + } + if (!e2->d_normalizedLength.get().isNull()) + { + e1->d_normalizedLength.set(e2->d_normalizedLength); + } + } +} + +void SolverState::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) +{ + if (t1.getType().isString()) + { + // store disequalities between strings, may need to check if their lengths + // are equal/disequal + d_eeDisequalities.push_back(t1.eqNode(t2)); + } +} + EqcInfo* SolverState::getOrMakeEqcInfo(Node eqc, bool doMake) { std::map::iterator eqc_i = d_eqcInfo.find(eqc); diff --git a/src/theory/strings/solver_state.h b/src/theory/strings/solver_state.h index 46d198d36..cb17e6d1b 100644 --- a/src/theory/strings/solver_state.h +++ b/src/theory/strings/solver_state.h @@ -88,6 +88,8 @@ class EqcInfo */ class SolverState { + typedef context::CDList NodeList; + public: SolverState(context::Context* c, eq::EqualityEngine& ee, Valuation& v); ~SolverState(); @@ -111,7 +113,18 @@ class SolverState bool areDisequal(Node a, Node b) const; /** get equality engine */ eq::EqualityEngine* getEqualityEngine() const; + /** + * Get the list of disequalities that are currently asserted to the equality + * engine. + */ + const context::CDList& getDisequalityList() const; //-------------------------------------- end equality information + //-------------------------------------- notifications for equalities + /** called when two equivalence classes will merge */ + void eqNotifyPreMerge(TNode t1, TNode t2); + /** called when two equivalence classes are made disequal */ + void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); + //-------------------------------------- end notifications for equalities //------------------------------------------ conflicts /** * Set that the current state of the solver is in conflict. This should be @@ -188,6 +201,11 @@ class SolverState context::Context* d_context; /** Reference to equality engine of the theory of strings. */ eq::EqualityEngine& d_ee; + /** + * The (SAT-context-dependent) list of disequalities that have been asserted + * to the equality engine above. + */ + NodeList d_eeDisequalities; /** Reference to the valuation of the theory of strings */ Valuation& d_valuation; /** Are we in conflict? */ diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 755e6b4df..152160cde 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -101,7 +101,6 @@ TheoryStrings::TheoryStrings(context::Context* c, d_registered_terms_cache(u), d_preproc(&d_sk_cache, u), d_extf_infer_cache(c), - d_ee_disequalities(c), d_congruent(c), d_proxy_var(u), d_proxy_var_to_length(u), @@ -1075,37 +1074,7 @@ void TheoryStrings::eqNotifyNewClass(TNode t){ /** called when two equivalance classes will merge */ void TheoryStrings::eqNotifyPreMerge(TNode t1, TNode t2){ - EqcInfo* e2 = d_state.getOrMakeEqcInfo(t2, false); - if( e2 ){ - EqcInfo* e1 = d_state.getOrMakeEqcInfo(t1); - //add information from e2 to e1 - if (!e2->d_lengthTerm.get().isNull()) - { - e1->d_lengthTerm.set(e2->d_lengthTerm); - } - if (!e2->d_codeTerm.get().isNull()) - { - e1->d_codeTerm.set(e2->d_codeTerm); - } - if (!e2->d_prefixC.get().isNull()) - { - d_state.setPendingConflictWhen( - e1->addEndpointConst(e2->d_prefixC, Node::null(), false)); - } - if (!e2->d_suffixC.get().isNull()) - { - d_state.setPendingConflictWhen( - e1->addEndpointConst(e2->d_suffixC, Node::null(), true)); - } - if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get()) - { - e1->d_cardinalityLemK.set(e2->d_cardinalityLemK); - } - if (!e2->d_normalizedLength.get().isNull()) - { - e1->d_normalizedLength.set(e2->d_normalizedLength); - } - } + d_state.eqNotifyPreMerge(t1, t2); } /** called when two equivalance classes have merged */ @@ -1115,10 +1084,7 @@ void TheoryStrings::eqNotifyPostMerge(TNode t1, TNode t2) { /** called when two equivalance classes are disequal */ void TheoryStrings::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) { - if( t1.getType().isString() ){ - //store disequalities between strings, may need to check if their lengths are equal/disequal - d_ee_disequalities.push_back( t1.eqNode( t2 ) ); - } + d_state.eqNotifyDisequal(t1, t2, reason); } void TheoryStrings::addCarePairs(TNodeTrie* t1, @@ -4139,10 +4105,12 @@ void TheoryStrings::checkNormalFormsDeq() std::vector< std::vector< Node > > cols; std::vector< Node > lts; std::map< Node, std::map< Node, bool > > processed; - + + const NodeList& deqs = d_state.getDisequalityList(); + //for each pair of disequal strings, must determine whether their lengths are equal or disequal - for( NodeList::const_iterator id = d_ee_disequalities.begin(); id != d_ee_disequalities.end(); ++id ) { - Node eq = *id; + for (const Node& eq : deqs) + { Node n[2]; for( unsigned i=0; i<2; i++ ){ n[i] = d_equalityEngine.getRepresentative( eq[i] ); diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index 990461027..ce92ada86 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -243,8 +243,6 @@ class TheoryStrings : public Theory { // extended functions inferences cache NodeSet d_extf_infer_cache; std::vector< Node > d_empty_vec; - // - NodeList d_ee_disequalities; private: NodeSet d_congruent; /** -- 2.30.2