From d0f7a3922e38483908d4b86829241a48d8d8db57 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 4 Feb 2020 09:31:22 -0600 Subject: [PATCH] Split base solver from the theory of strings (#3680) --- src/CMakeLists.txt | 2 + src/theory/strings/base_solver.cpp | 427 ++++++++++++++++++++++++++ src/theory/strings/base_solver.h | 191 ++++++++++++ src/theory/strings/theory_strings.cpp | 421 ++++--------------------- src/theory/strings/theory_strings.h | 66 +--- 5 files changed, 678 insertions(+), 429 deletions(-) create mode 100644 src/theory/strings/base_solver.cpp create mode 100644 src/theory/strings/base_solver.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 560f79976..26cc5bbcc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -664,6 +664,8 @@ libcvc4_add_sources( theory/shared_terms_database.h theory/sort_inference.cpp theory/sort_inference.h + theory/strings/base_solver.cpp + theory/strings/base_solver.h theory/strings/infer_info.cpp theory/strings/infer_info.h theory/strings/inference_manager.cpp diff --git a/src/theory/strings/base_solver.cpp b/src/theory/strings/base_solver.cpp new file mode 100644 index 000000000..2f5bc8e2b --- /dev/null +++ b/src/theory/strings/base_solver.cpp @@ -0,0 +1,427 @@ +/********************* */ +/*! \file base_solver.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Base solver for the theory of strings. This class implements term + ** indexing and constant inference for the theory of strings. 
+ **/ + +#include "theory/strings/base_solver.h" + +#include "options/strings_options.h" +#include "theory/strings/theory_strings_rewriter.h" +#include "theory/strings/theory_strings_utils.h" + +using namespace std; +using namespace CVC4::context; +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace strings { + +BaseSolver::BaseSolver(context::Context* c, + context::UserContext* u, + SolverState& s, + InferenceManager& im) + : d_state(s), d_im(im), d_congruent(c) +{ + d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String("")); + d_false = NodeManager::currentNM()->mkConst(false); +} + +BaseSolver::~BaseSolver() {} + +void BaseSolver::checkInit() +{ + // build term index + d_eqcToConst.clear(); + d_eqcToConstBase.clear(); + d_eqcToConstExp.clear(); + d_termIndex.clear(); + d_stringsEqc.clear(); + + std::map ncongruent; + std::map congruent; + eq::EqualityEngine* ee = d_state.getEqualityEngine(); + Assert(d_state.getRepresentative(d_emptyString) == d_emptyString); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee); + while (!eqcs_i.isFinished()) + { + Node eqc = (*eqcs_i); + TypeNode tn = eqc.getType(); + if (!tn.isRegExp()) + { + if (tn.isString()) + { + d_stringsEqc.push_back(eqc); + } + Node var; + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee); + while (!eqc_i.isFinished()) + { + Node n = *eqc_i; + if (n.isConst()) + { + d_eqcToConst[eqc] = n; + d_eqcToConstBase[eqc] = n; + d_eqcToConstExp[eqc] = Node::null(); + } + else if (tn.isInteger()) + { + // do nothing + } + else if (n.getNumChildren() > 0) + { + Kind k = n.getKind(); + if (k != EQUAL) + { + if (d_congruent.find(n) == d_congruent.end()) + { + std::vector c; + Node nc = d_termIndex[k].add(n, 0, d_state, d_emptyString, c); + if (nc != n) + { + // check if we have inferred a new equality by removal of empty + // components + if (n.getKind() == STRING_CONCAT && !d_state.areEqual(nc, n)) + { + std::vector exp; + size_t count[2] = {0, 0}; + while (count[0] 
< nc.getNumChildren() + || count[1] < n.getNumChildren()) + { + // explain empty prefixes + for (unsigned t = 0; t < 2; t++) + { + Node nn = t == 0 ? nc : n; + while ( + count[t] < nn.getNumChildren() + && (nn[count[t]] == d_emptyString + || d_state.areEqual(nn[count[t]], d_emptyString))) + { + if (nn[count[t]] != d_emptyString) + { + exp.push_back(nn[count[t]].eqNode(d_emptyString)); + } + count[t]++; + } + } + // explain equal components + if (count[0] < nc.getNumChildren()) + { + Assert(count[1] < n.getNumChildren()); + if (nc[count[0]] != n[count[1]]) + { + exp.push_back(nc[count[0]].eqNode(n[count[1]])); + } + count[0]++; + count[1]++; + } + } + // infer the equality + d_im.sendInference(exp, n.eqNode(nc), "I_Norm"); + } + else + { + // mark as congruent : only process if neither has been + // reduced + d_im.markCongruent(nc, n); + } + // this node is congruent to another one, we can ignore it + Trace("strings-process-debug") + << " congruent term : " << n << " (via " << nc << ")" + << std::endl; + d_congruent.insert(n); + congruent[k]++; + } + else if (k == STRING_CONCAT && c.size() == 1) + { + Trace("strings-process-debug") + << " congruent term by singular : " << n << " " << c[0] + << std::endl; + // singular case + if (!d_state.areEqual(c[0], n)) + { + Node ns; + std::vector exp; + // explain empty components + bool foundNEmpty = false; + for (const Node& nc : n) + { + if (d_state.areEqual(nc, d_emptyString)) + { + if (nc != d_emptyString) + { + exp.push_back(nc.eqNode(d_emptyString)); + } + } + else + { + Assert(!foundNEmpty); + ns = nc; + foundNEmpty = true; + } + } + AlwaysAssert(foundNEmpty); + // infer the equality + d_im.sendInference(exp, n.eqNode(ns), "I_Norm_S"); + } + d_congruent.insert(n); + congruent[k]++; + } + else + { + ncongruent[k]++; + } + } + else + { + congruent[k]++; + } + } + } + else + { + if (d_congruent.find(n) == d_congruent.end()) + { + // We mark all but the oldest variable in the equivalence class as + // congruent. 
+ if (var.isNull()) + { + var = n; + } + else if (var > n) + { + Trace("strings-process-debug") + << " congruent variable : " << var << std::endl; + d_congruent.insert(var); + var = n; + } + else + { + Trace("strings-process-debug") + << " congruent variable : " << n << std::endl; + d_congruent.insert(n); + } + } + } + ++eqc_i; + } + } + ++eqcs_i; + } + if (Trace.isOn("strings-process")) + { + for (std::map::iterator it = d_termIndex.begin(); + it != d_termIndex.end(); + ++it) + { + Trace("strings-process") + << " Terms[" << it->first << "] = " << ncongruent[it->first] << "/" + << (congruent[it->first] + ncongruent[it->first]) << std::endl; + } + } +} + +void BaseSolver::checkConstantEquivalenceClasses() +{ + // do fixed point + size_t prevSize = 0; + std::vector vecc; + do + { + vecc.clear(); + Trace("strings-process-debug") + << "Check constant equivalence classes..." << std::endl; + prevSize = d_eqcToConst.size(); + checkConstantEquivalenceClasses(&d_termIndex[STRING_CONCAT], vecc); + } while (!d_im.hasProcessed() && d_eqcToConst.size() > prevSize); +} + +void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti, + std::vector& vecc) +{ + Node n = ti->d_data; + if (!n.isNull()) + { + // construct the constant + Node c = utils::mkNConcat(vecc); + if (!d_state.areEqual(n, c)) + { + if (Trace.isOn("strings-debug")) + { + Trace("strings-debug") + << "Constant eqc : " << c << " for " << n << std::endl; + Trace("strings-debug") << " "; + for (const Node& v : vecc) + { + Trace("strings-debug") << v << " "; + } + Trace("strings-debug") << std::endl; + } + size_t count = 0; + size_t countc = 0; + std::vector exp; + while (count < n.getNumChildren()) + { + while (count < n.getNumChildren() + && d_state.areEqual(n[count], d_emptyString)) + { + d_im.addToExplanation(n[count], d_emptyString, exp); + count++; + } + if (count < n.getNumChildren()) + { + Trace("strings-debug") + << "...explain " << n[count] << " " << vecc[countc] << std::endl; + if 
(!d_state.areEqual(n[count], vecc[countc])) + { + Node nrr = d_state.getRepresentative(n[count]); + Assert(!d_eqcToConstExp[nrr].isNull()); + d_im.addToExplanation(n[count], d_eqcToConstBase[nrr], exp); + exp.push_back(d_eqcToConstExp[nrr]); + } + else + { + d_im.addToExplanation(n[count], vecc[countc], exp); + } + countc++; + count++; + } + } + // exp contains an explanation of n==c + Assert(countc == vecc.size()); + if (d_state.hasTerm(c)) + { + d_im.sendInference(exp, n.eqNode(c), "I_CONST_MERGE"); + return; + } + else if (!d_im.hasProcessed()) + { + Node nr = d_state.getRepresentative(n); + std::map::iterator it = d_eqcToConst.find(nr); + if (it == d_eqcToConst.end()) + { + Trace("strings-debug") + << "Set eqc const " << n << " to " << c << std::endl; + d_eqcToConst[nr] = c; + d_eqcToConstBase[nr] = n; + d_eqcToConstExp[nr] = utils::mkAnd(exp); + } + else if (c != it->second) + { + // conflict + Trace("strings-debug") + << "Conflict, other constant was " << it->second + << ", this constant was " << c << std::endl; + if (d_eqcToConstExp[nr].isNull()) + { + // n==c ^ n == c' => false + d_im.addToExplanation(n, it->second, exp); + } + else + { + // n==c ^ n == d_eqcToConstBase[nr] == c' => false + exp.push_back(d_eqcToConstExp[nr]); + d_im.addToExplanation(n, d_eqcToConstBase[nr], exp); + } + d_im.sendInference(exp, d_false, "I_CONST_CONFLICT"); + return; + } + else + { + Trace("strings-debug") << "Duplicate constant." 
<< std::endl; + } + } + } + } + for (std::pair& p : ti->d_children) + { + std::map::iterator itc = d_eqcToConst.find(p.first); + if (itc != d_eqcToConst.end()) + { + vecc.push_back(itc->second); + checkConstantEquivalenceClasses(&p.second, vecc); + vecc.pop_back(); + if (d_im.hasProcessed()) + { + break; + } + } + } +} + +bool BaseSolver::isCongruent(Node n) +{ + return d_congruent.find(n) != d_congruent.end(); +} + +Node BaseSolver::getConstantEqc(Node eqc) +{ + std::map::iterator it = d_eqcToConst.find(eqc); + if (it != d_eqcToConst.end()) + { + return it->second; + } + return Node::null(); +} + +Node BaseSolver::explainConstantEqc(Node n, Node eqc, std::vector& exp) +{ + std::map::iterator it = d_eqcToConst.find(eqc); + if (it != d_eqcToConst.end()) + { + if (!d_eqcToConstExp[eqc].isNull()) + { + exp.push_back(d_eqcToConstExp[eqc]); + } + if (!d_eqcToConstBase[eqc].isNull()) + { + d_im.addToExplanation(n, d_eqcToConstBase[eqc], exp); + } + return it->second; + } + return Node::null(); +} + +const std::vector& BaseSolver::getStringEqc() const +{ + return d_stringsEqc; +} + +Node BaseSolver::TermIndex::add(TNode n, + unsigned index, + const SolverState& s, + Node er, + std::vector& c) +{ + if (index == n.getNumChildren()) + { + if (d_data.isNull()) + { + d_data = n; + } + return d_data; + } + Assert(index < n.getNumChildren()); + TNode nir = s.getRepresentative(n[index]); + // if it is empty, and doing CONCAT, ignore + if (nir == er && n.getKind() == STRING_CONCAT) + { + return add(n, index + 1, s, er, c); + } + c.push_back(nir); + return d_children[nir].add(n, index + 1, s, er, c); +} + +} // namespace strings +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/strings/base_solver.h b/src/theory/strings/base_solver.h new file mode 100644 index 000000000..c87a3af9e --- /dev/null +++ b/src/theory/strings/base_solver.h @@ -0,0 +1,191 @@ +/********************* */ +/*! 
\file base_solver.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Base solver for term indexing and constant inference for the + ** theory of strings. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__STRINGS__BASE_SOLVER_H +#define CVC4__THEORY__STRINGS__BASE_SOLVER_H + +#include "context/cdhashset.h" +#include "context/cdlist.h" +#include "theory/strings/infer_info.h" +#include "theory/strings/inference_manager.h" +#include "theory/strings/normal_form.h" +#include "theory/strings/skolem_cache.h" +#include "theory/strings/solver_state.h" + +namespace CVC4 { +namespace theory { +namespace strings { + +/** The base solver for the theory of strings + * + * This implements techniques for inferring when terms are congruent in the + * current context, and techniques for inferring when equivalence classes + * are equivalent to constants. + */ +class BaseSolver +{ + using NodeSet = context::CDHashSet; + + public: + BaseSolver(context::Context* c, + context::UserContext* u, + SolverState& s, + InferenceManager& im); + ~BaseSolver(); + + //-----------------------inference steps + /** check initial + * + * This function initializes term indices for each strings function symbol. + * One key aspect of this construction is that concat terms are indexed by + * their list of non-empty components. For example, if x = "" is an equality + * asserted in this SAT context, then y ++ x ++ z may be indexed by (y,z). + * This method may infer various facts while building these term indices, for + * instance, based on congruence. 
An example would be inferring: + * y ++ x ++ z = y ++ z + * if both terms are registered in this SAT context. + * + * This function should be called as a first step of any strategy. + */ + void checkInit(); + /** check constant equivalence classes + * + * This function infers whether CONCAT terms can be simplified to constants. + * For example, if x = "a" and y = "b" are equalities in the current SAT + * context, then we may infer x ++ "c" ++ y is equivalent to "acb". In this + * case, we infer the fact x ++ "c" ++ y = "acb". + */ + void checkConstantEquivalenceClasses(); + //-----------------------end inference steps + + //-----------------------query functions + /** + * Is n congruent to another term in the current context that has not been + * marked congruent? If so, we can ignore n. + * + * Note this and the functions in this block below are valid during a full + * effort check after a call to checkInit. + */ + bool isCongruent(Node n); + /** + * Get the constant that the equivalence class eqc is entailed to be equal + * to, or null if none exist. + */ + Node getConstantEqc(Node eqc); + /** + * Same as above, where the explanation for n = c is added to exp if c is + * the (non-null) return value of this function, where n is a term in the + * equivalence class of eqc. + */ + Node explainConstantEqc(Node n, Node eqc, std::vector& exp); + /** + * Get the set of equivalence classes of type string. + */ + const std::vector& getStringEqc() const; + //-----------------------end query functions + + private: + /** + * A term index that considers terms modulo flattening and constant merging + * for concatenation terms. + */ + class TermIndex + { + public: + /** Add n to this trie + * + * A term is indexed by flattening arguments of concatenation terms, + * and replacing them by (non-empty) constants when possible, for example + * if n is (str.++ x y z) and x = "abc" and y = "" are asserted, then n is + * indexed by ("abc", z). 
+ * + * index: the child of n we are currently processing, + * s : reference to solver state, + * er : the representative of the empty equivalence class. + * + * We store the vector of terms that n was indexed by in the vector c. + */ + Node add(TNode n, + unsigned index, + const SolverState& s, + Node er, + std::vector& c); + /** Clear this trie */ + void clear() { d_children.clear(); } + /** The data at this node of the trie */ + Node d_data; + /** The children of this node of the trie */ + std::map d_children; + }; + /** + * This method is called as we are traversing the term index ti, where vecc + * accumulates the list of constants in the path to ti. If ti has a non-null + * data n, then we have inferred that d_data is equivalent to the + * constant specified by vecc. + */ + void checkConstantEquivalenceClasses(TermIndex* ti, std::vector& vecc); + /** The solver state object */ + SolverState& d_state; + /** The (custom) output channel of the theory of strings */ + InferenceManager& d_im; + /** Commonly used constants */ + Node d_emptyString; + Node d_false; + /** + * A congruence class is a set of terms f( t1 ), ..., f( tn ) where + * t1 = ... = tn. Congruence classes are important since all but + * one of the above terms (the representative of the congruence class) + * can be ignored by the solver. + * + * This set contains a set of nodes that are not representatives of their + * congruence class. This set is used to skip reasoning about terms in + * various inference schemas implemented by this class. + */ + NodeSet d_congruent; + /** + * The following three maps are used for tracking constants that each + * equivalence class is entailed to be equal to. 
+ * - The map d_eqcToConst maps (representatives) r of equivalence classes to + * the constant that that equivalence class is entailed to be equal to, + * - The term d_eqcToConstBase[r] is the term in the equivalence class r + * that is entailed to be equal to the constant d_eqcToConst[r], + * - The term d_eqcToConstExp[r] is the explanation of why + * d_eqcToConstBase[r] is equal to d_eqcToConst[r]. + * + * For example, consider the equivalence class { r, x++"a"++y, x++z }, and + * assume x = "" and y = "bb" in the current context. We have that + * d_eqcToConst[r] = "abb", + * d_eqcToConstBase[r] = x++"a"++y + * d_eqcToConstExp[r] = ( x = "" AND y = "bb" ) + * + * This information is computed during checkInit and is used during various + * inference schemas for deriving inferences. + */ + std::map d_eqcToConst; + std::map d_eqcToConstBase; + std::map d_eqcToConstExp; + /** The list of equivalence classes of type string */ + std::vector d_stringsEqc; + /** A term index for each function kind */ + std::map d_termIndex; +}; /* class BaseSolver */ + +} // namespace strings +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__STRINGS__BASE_SOLVER_H */ diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index c3a67aec9..7ebc5f35f 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -62,30 +62,6 @@ std::ostream& operator<<(std::ostream& out, InferStep s) return out; } -Node TheoryStrings::TermIndex::add(TNode n, - unsigned index, - const SolverState& s, - Node er, - std::vector& c) -{ - if( index==n.getNumChildren() ){ - if( d_data.isNull() ){ - d_data = n; - } - return d_data; - }else{ - Assert(index < n.getNumChildren()); - TNode nir = s.getRepresentative(n[index]); - //if it is empty, and doing CONCAT, ignore - if( nir==er && n.getKind()==kind::STRING_CONCAT ){ - return add(n, index + 1, s, er, c); - }else{ - c.push_back( nir ); - return d_children[nir].add(n, index + 
1, s, er, c); - } - } -} - TheoryStrings::TheoryStrings(context::Context* c, context::UserContext* u, OutputChannel& out, @@ -101,12 +77,12 @@ TheoryStrings::TheoryStrings(context::Context* c, d_registered_terms_cache(u), d_preproc(&d_sk_cache, u), d_extf_infer_cache(c), - d_congruent(c), d_proxy_var(u), d_proxy_var_to_length(u), d_functionsTerms(c), d_has_extf(c, false), d_has_str_code(false), + d_bsolver(c, u, d_state, d_im), d_regexp_solver(*this, d_state, d_im, c, u), d_input_vars(u), d_input_var_lsum(u), @@ -297,22 +273,10 @@ Node TheoryStrings::getCurrentSubstitutionFor(int effort, return mv; } Node nr = d_state.getRepresentative(n); - std::map::iterator itc = d_eqc_to_const.find(nr); - if (itc != d_eqc_to_const.end()) + Node c = d_bsolver.explainConstantEqc(n, nr, exp); + if (!c.isNull()) { - // constant equivalence classes - Trace("strings-subs") << " constant eqc : " << d_eqc_to_const_exp[nr] - << " " << d_eqc_to_const_base[nr] << " " << nr - << std::endl; - if (!d_eqc_to_const_exp[nr].isNull()) - { - exp.push_back(d_eqc_to_const_exp[nr]); - } - if (!d_eqc_to_const_base[nr].isNull()) - { - d_im.addToExplanation(n, d_eqc_to_const_base[nr], exp); - } - return itc->second; + return c; } else if (effort >= 1 && n.getType().isString()) { @@ -1244,266 +1208,6 @@ void TheoryStrings::assertPendingFact(Node atom, bool polarity, Node exp) { Trace("strings-pending-debug") << " Finished collect terms" << std::endl; } -void TheoryStrings::checkInit() { - //build term index - d_eqc_to_const.clear(); - d_eqc_to_const_base.clear(); - d_eqc_to_const_exp.clear(); - d_eqc_to_len_term.clear(); - d_term_index.clear(); - d_strings_eqc.clear(); - - std::map< Kind, unsigned > ncongruent; - std::map< Kind, unsigned > congruent; - d_emptyString_r = d_state.getRepresentative(d_emptyString); - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); - while( !eqcs_i.isFinished() ){ - Node eqc = (*eqcs_i); - TypeNode tn = eqc.getType(); - if( !tn.isRegExp() ){ - if( 
tn.isString() ){ - d_strings_eqc.push_back( eqc ); - } - Node var; - eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine ); - while( !eqc_i.isFinished() ) { - Node n = *eqc_i; - if( n.isConst() ){ - d_eqc_to_const[eqc] = n; - d_eqc_to_const_base[eqc] = n; - d_eqc_to_const_exp[eqc] = Node::null(); - }else if( tn.isInteger() ){ - if( n.getKind()==kind::STRING_LENGTH ){ - Node nr = d_state.getRepresentative(n[0]); - d_eqc_to_len_term[nr] = n[0]; - } - }else if( n.getNumChildren()>0 ){ - Kind k = n.getKind(); - if( k!=kind::EQUAL ){ - if( d_congruent.find( n )==d_congruent.end() ){ - std::vector< Node > c; - Node nc = d_term_index[k].add(n, 0, d_state, d_emptyString_r, c); - if( nc!=n ){ - //check if we have inferred a new equality by removal of empty components - if (n.getKind() == kind::STRING_CONCAT - && !d_state.areEqual(nc, n)) - { - std::vector< Node > exp; - unsigned count[2] = { 0, 0 }; - while( count[0]hasFunctionKind(n.getKind())) - { - //mark as congruent : only process if neither has been reduced - getExtTheory()->markCongruent( nc, n ); - } - //this node is congruent to another one, we can ignore it - Trace("strings-process-debug") - << " congruent term : " << n << " (via " << nc << ")" - << std::endl; - d_congruent.insert( n ); - congruent[k]++; - }else if( k==kind::STRING_CONCAT && c.size()==1 ){ - Trace("strings-process-debug") << " congruent term by singular : " << n << " " << c[0] << std::endl; - //singular case - if (!d_state.areEqual(c[0], n)) - { - Node ns; - std::vector< Node > exp; - //explain empty components - bool foundNEmpty = false; - for( unsigned i=0; i n) - { - Trace("strings-process-debug") - << " congruent variable : " << var << std::endl; - d_congruent.insert(var); - var = n; - } - else - { - Trace("strings-process-debug") - << " congruent variable : " << n << std::endl; - d_congruent.insert(n); - } - } - } - ++eqc_i; - } - } - ++eqcs_i; - } - if( Trace.isOn("strings-process") ){ - for( std::map< Kind, TermIndex 
>::iterator it = d_term_index.begin(); it != d_term_index.end(); ++it ){ - Trace("strings-process") << " Terms[" << it->first << "] = " << ncongruent[it->first] << "/" << (congruent[it->first]+ncongruent[it->first]) << std::endl; - } - } -} - -void TheoryStrings::checkConstantEquivalenceClasses() -{ - // do fixed point - unsigned prevSize; - std::vector vecc; - do - { - vecc.clear(); - Trace("strings-process-debug") << "Check constant equivalence classes..." - << std::endl; - prevSize = d_eqc_to_const.size(); - checkConstantEquivalenceClasses(&d_term_index[kind::STRING_CONCAT], vecc); - } while (!d_im.hasProcessed() && d_eqc_to_const.size() > prevSize); -} - -void TheoryStrings::checkConstantEquivalenceClasses( TermIndex* ti, std::vector< Node >& vecc ) { - Node n = ti->d_data; - if( !n.isNull() ){ - //construct the constant - Node c = utils::mkNConcat(vecc); - if (!d_state.areEqual(n, c)) - { - Trace("strings-debug") << "Constant eqc : " << c << " for " << n << std::endl; - Trace("strings-debug") << " "; - for( unsigned i=0; i exp; - while( count::iterator it = d_eqc_to_const.find( nr ); - if( it==d_eqc_to_const.end() ){ - Trace("strings-debug") << "Set eqc const " << n << " to " << c << std::endl; - d_eqc_to_const[nr] = c; - d_eqc_to_const_base[nr] = n; - d_eqc_to_const_exp[nr] = utils::mkAnd(exp); - }else if( c!=it->second ){ - //conflict - Trace("strings-debug") << "Conflict, other constant was " << it->second << ", this constant was " << c << std::endl; - if( d_eqc_to_const_exp[nr].isNull() ){ - // n==c ^ n == c' => false - d_im.addToExplanation(n, it->second, exp); - }else{ - // n==c ^ n == d_eqc_to_const_base[nr] == c' => false - exp.push_back( d_eqc_to_const_exp[nr] ); - d_im.addToExplanation(n, d_eqc_to_const_base[nr], exp); - } - d_im.sendInference(exp, d_false, "I_CONST_CONFLICT"); - return; - }else{ - Trace("strings-debug") << "Duplicate constant." 
<< std::endl; - } - } - } - } - for( std::map< TNode, TermIndex >::iterator it = ti->d_children.begin(); it != ti->d_children.end(); ++it ){ - std::map< Node, Node >::iterator itc = d_eqc_to_const.find( it->first ); - if( itc!=d_eqc_to_const.end() ){ - vecc.push_back( itc->second ); - checkConstantEquivalenceClasses( &it->second, vecc ); - vecc.pop_back(); - if (d_im.hasProcessed()) - { - break; - } - } - } -} - void TheoryStrings::checkExtfEval( int effort ) { Trace("strings-extf-list") << "Active extended functions, effort=" << effort << " : " << std::endl; d_extf_info_tmp.clear(); @@ -1515,11 +1219,7 @@ void TheoryStrings::checkExtfEval( int effort ) { // Setup information about n, including if it is equal to a constant. ExtfInfoTmp& einfo = d_extf_info_tmp[n]; Node r = d_state.getRepresentative(n); - std::map::iterator itcit = d_eqc_to_const.find(r); - if (itcit != d_eqc_to_const.end()) - { - einfo.d_const = itcit->second; - } + einfo.d_const = d_bsolver.getConstantEqc(r); // Get the current values of the children of n. // Notice that we look up the value of the direct children of n, and not // their free variables. 
In other words, given a term: @@ -1718,16 +1418,8 @@ void TheoryStrings::checkExtfInference( Node n, Node nr, ExtfInfoTmp& in, int ef { // otherwise, must explain via base node Node r = d_state.getRepresentative(n); - // we have that: - // d_eqc_to_const_exp[r] => d_eqc_to_const_base[r] = in.d_const - // thus: - // n = d_eqc_to_const_base[r] ^ d_eqc_to_const_exp[r] => n = in.d_const - Assert(d_eqc_to_const_base.find(r) != d_eqc_to_const_base.end()); - d_im.addToExplanation(n, d_eqc_to_const_base[r], in.d_exp); - Assert(d_eqc_to_const_exp.find(r) != d_eqc_to_const_exp.end()); - in.d_exp.insert(in.d_exp.end(), - d_eqc_to_const_exp[r].begin(), - d_eqc_to_const_exp[r].end()); + // explain using the base solver + d_bsolver.explainConstantEqc(n, r, in.d_exp); } // d_extf_infer_cache stores whether we have made the inferences associated @@ -1932,15 +1624,6 @@ Node TheoryStrings::getSymbolicDefinition(Node n, std::vector& exp) const } } -Node TheoryStrings::getConstantEqc( Node eqc ) { - std::map< Node, Node >::iterator it = d_eqc_to_const.find( eqc ); - if( it!=d_eqc_to_const.end() ){ - return it->second; - }else{ - return Node::null(); - } -} - void TheoryStrings::debugPrintFlatForms( const char * tc ){ for( unsigned k=0; k::iterator itc = d_eqc_to_const.find( eqc ); - if( itc!=d_eqc_to_const.end() ){ - Trace( tc ) << " C: " << itc->second; + Node c = d_bsolver.getConstantEqc(eqc); + if (!c.isNull()) + { + Trace(tc) << " C: " << c; if( d_eqc[eqc].size()>1 ){ Trace( tc ) << std::endl; } @@ -1962,11 +1646,14 @@ void TheoryStrings::debugPrintFlatForms( const char * tc ){ Trace( tc ) << " "; for( unsigned j=0; jsecond; - }else{ + if (!fcc.isNull()) + { + Trace(tc) << fcc; + } + else + { Trace( tc ) << fc; } } @@ -2009,17 +1696,18 @@ void TheoryStrings::checkCycles() d_flat_form.clear(); d_flat_form_index.clear(); d_eqc.clear(); - //rebuild strings eqc based on acyclic ordering - std::vector< Node > eqc; - eqc.insert( eqc.end(), d_strings_eqc.begin(), d_strings_eqc.end() ); + 
// Rebuild strings eqc based on acyclic ordering, first copy the equivalence + // classes from the base solver. + std::vector eqc = d_bsolver.getStringEqc(); d_strings_eqc.clear(); if( options::stringBinaryCsp() ){ //sort: process smallest constants first (necessary if doing binary splits) sortConstLength scl; for( unsigned i=0; i::iterator itc = d_eqc_to_const.find( eqc[i] ); - if( itc!=d_eqc_to_const.end() ){ - scl.d_const_length[eqc[i]] = itc->second.getConst().size(); + Node ci = d_bsolver.getConstantEqc(eqc[i]); + if (!ci.isNull()) + { + scl.d_const_length[eqc[i]] = ci.getConst().size(); } } std::sort( eqc.begin(), eqc.end(), scl ); @@ -2049,7 +1737,7 @@ void TheoryStrings::checkFlatForms() //(1) approximate equality by containment, infer conflicts for (const Node& eqc : d_strings_eqc) { - Node c = getConstantEqc(eqc); + Node c = d_bsolver.getConstantEqc(eqc); if (!c.isNull()) { // if equivalence class is constant, all component constants in flat forms @@ -2071,13 +1759,7 @@ void TheoryStrings::checkFlatForms() // conflict, explanation is n = base ^ base = c ^ relevant portion // of ( n = f[n] ) std::vector exp; - Assert(d_eqc_to_const_base.find(eqc) != d_eqc_to_const_base.end()); - d_im.addToExplanation(n, d_eqc_to_const_base[eqc], exp); - Assert(d_eqc_to_const_exp.find(eqc) != d_eqc_to_const_exp.end()); - if (!d_eqc_to_const_exp[eqc].isNull()) - { - exp.push_back(d_eqc_to_const_exp[eqc]); - } + d_bsolver.explainConstantEqc(n, eqc, exp); for (int e = firstc; e <= lastc; e++) { if (d_flat_form[n][e].isConst()) @@ -2216,7 +1898,7 @@ void TheoryStrings::checkFlatForm(std::vector& eqc, else { Node curr = d_flat_form[a][count]; - Node curr_c = getConstantEqc(curr); + Node curr_c = d_bsolver.getConstantEqc(curr); Node ac = a[d_flat_form_index[a][count]]; std::vector lexp; Node lcurr = d_state.getLength(ac, lexp); @@ -2254,7 +1936,7 @@ void TheoryStrings::checkFlatForm(std::vector& eqc, Node bc = b[d_flat_form_index[b][count]]; inelig.push_back(b); 
Assert(!d_state.areEqual(curr, cc)); - Node cc_c = getConstantEqc(cc); + Node cc_c = d_bsolver.getConstantEqc(cc); if (!curr_c.isNull() && !cc_c.isNull()) { // check for constant conflict @@ -2263,10 +1945,8 @@ void TheoryStrings::checkFlatForm(std::vector& eqc, cc_c, curr_c, index, isRev); if (s.isNull()) { - d_im.addToExplanation(ac, d_eqc_to_const_base[curr], exp); - d_im.addToExplanation(d_eqc_to_const_exp[curr], exp); - d_im.addToExplanation(bc, d_eqc_to_const_base[cc], exp); - d_im.addToExplanation(d_eqc_to_const_exp[cc], exp); + d_bsolver.explainConstantEqc(ac, curr, exp); + d_bsolver.explainConstantEqc(bc, cc, exp); conc = d_false; infType = FlatFormInfer::CONST; break; @@ -2386,25 +2066,32 @@ Node TheoryStrings::checkCycles( Node eqc, std::vector< Node >& curr, std::vecto eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine ); while( !eqc_i.isFinished() ) { Node n = (*eqc_i); - if( d_congruent.find( n )==d_congruent.end() ){ + if (!d_bsolver.isCongruent(n)) + { if( n.getKind() == kind::STRING_CONCAT ){ Trace("strings-cycle") << eqc << " check term : " << n << " in " << eqc << std::endl; - if( eqc!=d_emptyString_r ){ + if (eqc != d_emptyString) + { d_eqc[eqc].push_back( n ); } for( unsigned i=0; i exp; exp.push_back( n.eqNode( d_emptyString ) ); d_im.sendInference( exp, n[i].eqNode(d_emptyString), "I_CYCLE_E"); return Node::null(); } - }else{ - if( nr!=d_emptyString_r ){ + } + else + { + if (nr != d_emptyString) + { d_flat_form[n].push_back( nr ); d_flat_form_index[n].push_back( i ); } @@ -2460,7 +2147,7 @@ void TheoryStrings::checkRegisterTermsPreNormalForm() while (!eqc_i.isFinished()) { Node n = (*eqc_i); - if (d_congruent.find(n) == d_congruent.end()) + if (!d_bsolver.isCongruent(n)) { registerTerm(n, 2); } @@ -2686,7 +2373,8 @@ void TheoryStrings::getNormalForms(Node eqc, eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine ); while( !eqc_i.isFinished() ){ Node n = (*eqc_i); - if( d_congruent.find( n 
)==d_congruent.end() ){ + if (!d_bsolver.isCongruent(n)) + { if (n.getKind() == CONST_STRING || n.getKind() == STRING_CONCAT) { Trace("strings-process-debug") << "Get Normal Form : Process term " << n << " in eqc " << eqc << std::endl; @@ -2867,7 +2555,7 @@ void TheoryStrings::getNormalForms(Node eqc, } //if equivalence class is constant, approximate as containment, infer conflicts - Node c = getConstantEqc( eqc ); + Node c = d_bsolver.getConstantEqc(eqc); if( !c.isNull() ){ Trace("strings-solve") << "Eqc is constant " << c << std::endl; for (unsigned i = 0, size = normal_forms.size(); i < size; i++) @@ -2882,12 +2570,7 @@ void TheoryStrings::getNormalForms(Node eqc, Trace("strings-solve") << "Normal form for " << n << " cannot be contained in constant " << c << std::endl; //conflict, explanation is n = base ^ base = c ^ relevant porition of ( n = N[n] ) std::vector< Node > exp; - Assert(d_eqc_to_const_base.find(eqc) != d_eqc_to_const_base.end()); - d_im.addToExplanation(n, d_eqc_to_const_base[eqc], exp); - Assert(d_eqc_to_const_exp.find(eqc) != d_eqc_to_const_exp.end()); - if( !d_eqc_to_const_exp[eqc].isNull() ){ - exp.push_back( d_eqc_to_const_exp[eqc] ); - } + d_bsolver.explainConstantEqc(n, eqc, exp); //TODO: this can be minimized based on firstc/lastc, normal_forms_exp_depend exp.insert(exp.end(), nf.d_exp.begin(), nf.d_exp.end()); Node conc = d_false; @@ -3850,7 +3533,7 @@ int TheoryStrings::processSimpleDeq( std::vector< Node >& nfi, std::vector< Node { for (unsigned i = 0; i < 2; i++) { - Node c = getConstantEqc(i == 0 ? ni : nj); + Node c = d_bsolver.getConstantEqc(i == 0 ? ni : nj); if (!c.isNull()) { int findex, lindex; @@ -4465,8 +4148,8 @@ void TheoryStrings::runInferStep(InferStep s, int effort) Trace("strings-process") << "..." 
<< std::endl; switch (s) { - case CHECK_INIT: checkInit(); break; - case CHECK_CONST_EQC: checkConstantEquivalenceClasses(); break; + case CHECK_INIT: d_bsolver.checkInit(); break; + case CHECK_CONST_EQC: d_bsolver.checkConstantEquivalenceClasses(); break; case CHECK_EXTF_EVAL: checkExtfEval(effort); break; case CHECK_CYCLES: checkCycles(); break; case CHECK_FLAT_FORMS: checkFlatForms(); break; diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index ce92ada86..3b53fcded 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -24,6 +24,7 @@ #include "expr/attribute.h" #include "expr/node_trie.h" #include "theory/decision_manager.h" +#include "theory/strings/base_solver.h" #include "theory/strings/infer_info.h" #include "theory/strings/inference_manager.h" #include "theory/strings/normal_form.h" @@ -244,30 +245,6 @@ class TheoryStrings : public Theory { NodeSet d_extf_infer_cache; std::vector< Node > d_empty_vec; private: - NodeSet d_congruent; - /** - * The following three vectors are used for tracking constants that each - * equivalence class is entailed to be equal to. - * - The map d_eqc_to_const maps (representatives) r of equivalence classes to - * the constant that that equivalence class is entailed to be equal to, - * - The term d_eqc_to_const_base[r] is the term in the equivalence class r - * that is entailed to be equal to the constant d_eqc_to_const[r], - * - The term d_eqc_to_const_exp[r] is the explanation of why - * d_eqc_to_const_base[r] is equal to d_eqc_to_const[r]. - * - * For example, consider the equivalence class { r, x++"a"++y, x++z }, and - * assume x = "" and y = "bb" in the current context. We have that - * d_eqc_to_const[r] = "abb", - * d_eqc_to_const_base[r] = x++"a"++y - * d_eqc_to_const_exp[r] = ( x = "" AND y = "bb" ) - * - * This information is computed during checkInit and is used during various - * inference schemas for deriving inferences. 
- */ - std::map< Node, Node > d_eqc_to_const; - std::map< Node, Node > d_eqc_to_const_base; - std::map< Node, Node > d_eqc_to_const_exp; - Node getConstantEqc( Node eqc ); /** * Get the current substitution for term n. * @@ -285,19 +262,6 @@ private: std::map< Node, Node > d_eqc_to_len_term; std::vector< Node > d_strings_eqc; - Node d_emptyString_r; - class TermIndex { - public: - Node d_data; - std::map< TNode, TermIndex > d_children; - Node add(TNode n, - unsigned index, - const SolverState& s, - Node er, - std::vector& c); - void clear(){ d_children.clear(); } - }; - std::map< Kind, TermIndex > d_term_index; //list of non-congruent concat terms in each eqc std::map< Node, std::vector< Node > > d_eqc; std::map< Node, std::vector< Node > > d_flat_form; @@ -360,7 +324,6 @@ private: /** cache of all skolems */ SkolemCache d_sk_cache; - void checkConstantEquivalenceClasses( TermIndex* ti, std::vector< Node >& vecc ); /** Get proxy variable * * If this method returns the proxy variable for (string) term n if one @@ -619,6 +582,11 @@ private: // Symbolic Regular Expression private: + /** + * The base solver, responsible for reasoning about congruent terms and + * inferring constants for equivalence classes. + */ + BaseSolver d_bsolver; /** regular expression solver module */ RegExpSolver d_regexp_solver; /** regular expression elimination module */ @@ -688,28 +656,6 @@ private: private: //-----------------------inference steps - /** check initial - * - * This function initializes term indices for each strings function symbol. - * One key aspect of this construction is that concat terms are indexed by - * their list of non-empty components. For example, if x = "" is an equality - * asserted in this SAT context, then y ++ x ++ z may be indexed by (y,z). - * This method may infer various facts while building these term indices, for - * instance, based on congruence. 
An example would be inferring: - * y ++ x ++ z = y ++ z - * if both terms are registered in this SAT context. - * - * This function should be called as a first step of any strategy. - */ - void checkInit(); - /** check constant equivalence classes - * - * This function infers whether CONCAT terms can be simplified to constants. - * For example, if x = "a" and y = "b" are equalities in the current SAT - * context, then we may infer x ++ "c" ++ y is equivalent to "acb". In this - * case, we infer the fact x ++ "c" ++ y = "acb". - */ - void checkConstantEquivalenceClasses(); /** check extended functions evaluation * * This applies "context-dependent simplification" for all active extended -- 2.30.2