From 9b20af281db3e77a25540305dfb73cbe56519f75 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Sat, 22 Feb 2020 01:05:43 -0600 Subject: [PATCH] Move cardinality inference scheme to base solver in strings (#3792) Moves handling of cardinality to the base solver, refactors how cardinality is accessed. No intended behavior change in this commit. --- src/theory/strings/base_solver.cpp | 133 ++++++++++++++++++ src/theory/strings/base_solver.h | 9 ++ src/theory/strings/regexp_operation.cpp | 2 +- src/theory/strings/regexp_operation.h | 2 +- src/theory/strings/theory_strings.cpp | 110 +-------------- src/theory/strings/theory_strings.h | 9 +- .../strings/theory_strings_rewriter.cpp | 11 -- src/theory/strings/theory_strings_rewriter.h | 2 - src/theory/strings/theory_strings_utils.cpp | 12 ++ src/theory/strings/theory_strings_utils.h | 3 + src/theory/strings/type_enumerator.h | 11 +- 11 files changed, 170 insertions(+), 134 deletions(-) diff --git a/src/theory/strings/base_solver.cpp b/src/theory/strings/base_solver.cpp index 2f5bc8e2b..343f6e4f8 100644 --- a/src/theory/strings/base_solver.cpp +++ b/src/theory/strings/base_solver.cpp @@ -35,6 +35,7 @@ BaseSolver::BaseSolver(context::Context* c, { d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String("")); d_false = NodeManager::currentNM()->mkConst(false); + d_cardSize = utils::getAlphabetCardinality(); } BaseSolver::~BaseSolver() {} @@ -359,6 +360,138 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti, } } +void BaseSolver::checkCardinality() +{ + // This will create a partition of eqc, where each collection has length that + // are pairwise propagated to be equal. We do not require disequalities + // between the lengths of each collection, since we split on disequalities + // between lengths of string terms that are disequal (DEQ-LENGTH-SP). + std::vector > cols; + std::vector lts; + d_state.separateByLength(d_stringsEqc, cols, lts); + NodeManager* nm = NodeManager::currentNM(); + Trace("strings-card") << "Check cardinality...." << std::endl; + // for each collection + for (unsigned i = 0, csize = cols.size(); i < csize; ++i) + { + Node lr = lts[i]; + Trace("strings-card") << "Number of strings with length equal to " << lr + << " is " << cols[i].size() << std::endl; + if (cols[i].size() <= 1) + { + // no restriction on sets in the partition of size 1 + continue; + } + // size > c^k + unsigned card_need = 1; + double curr = static_cast(cols[i].size()); + while (curr > d_cardSize) + { + curr = curr / static_cast(d_cardSize); + card_need++; + } + Trace("strings-card") + << "Need length " << card_need + << " for this number of strings (where alphabet size is " << d_cardSize + << ")." << std::endl; + // check if we need to split + bool needsSplit = true; + if (lr.isConst()) + { + // if constant, compare + Node cmp = nm->mkNode(GEQ, lr, nm->mkConst(Rational(card_need))); + cmp = Rewriter::rewrite(cmp); + needsSplit = !cmp.getConst(); + } + else + { + // find the minimimum constant that we are unknown to be disequal from, or + // otherwise stop if we increment such that cardinality does not apply + unsigned r = 0; + bool success = true; + while (r < card_need && success) + { + Node rr = nm->mkConst(Rational(r)); + if (d_state.areDisequal(rr, lr)) + { + r++; + } + else + { + success = false; + } + } + if (r > 0) + { + Trace("strings-card") + << "Symbolic length " << lr << " must be at least " << r + << " due to constant disequalities." << std::endl; + } + needsSplit = r < card_need; + } + + if (!needsSplit) + { + // don't need to split + continue; + } + // first, try to split to merge equivalence classes + for (std::vector::iterator itr1 = cols[i].begin(); + itr1 != cols[i].end(); + ++itr1) + { + for (std::vector::iterator itr2 = itr1 + 1; itr2 != cols[i].end(); + ++itr2) + { + if (!d_state.areDisequal(*itr1, *itr2)) + { + // add split lemma + if (d_im.sendSplit(*itr1, *itr2, "CARD-SP")) + { + return; + } + } + } + } + // otherwise, we need a length constraint + uint32_t int_k = static_cast(card_need); + EqcInfo* ei = d_state.getOrMakeEqcInfo(lr, true); + Trace("strings-card") << "Previous cardinality used for " << lr << " is " + << ((int)ei->d_cardinalityLemK.get() - 1) + << std::endl; + if (int_k + 1 > ei->d_cardinalityLemK.get()) + { + Node k_node = nm->mkConst(Rational(int_k)); + // add cardinality lemma + Node dist = nm->mkNode(DISTINCT, cols[i]); + std::vector vec_node; + vec_node.push_back(dist); + for (std::vector::iterator itr1 = cols[i].begin(); + itr1 != cols[i].end(); + ++itr1) + { + Node len = nm->mkNode(STRING_LENGTH, *itr1); + if (len != lr) + { + Node len_eq_lr = len.eqNode(lr); + vec_node.push_back(len_eq_lr); + } + } + Node len = nm->mkNode(STRING_LENGTH, cols[i][0]); + Node cons = nm->mkNode(GEQ, len, k_node); + cons = Rewriter::rewrite(cons); + ei->d_cardinalityLemK.set(int_k + 1); + if (!cons.isConst() || !cons.getConst()) + { + std::vector emptyVec; + d_im.sendInference(emptyVec, vec_node, cons, "CARDINALITY", true); + return; + } + } + } + Trace("strings-card") << "...end check cardinality" << std::endl; +} + bool BaseSolver::isCongruent(Node n) { return d_congruent.find(n) != d_congruent.end(); diff --git a/src/theory/strings/base_solver.h b/src/theory/strings/base_solver.h index c87a3af9e..3681b49a4 100644 --- a/src/theory/strings/base_solver.h +++ b/src/theory/strings/base_solver.h @@ -70,6 +70,13 @@ class BaseSolver * case, we infer the fact x ++ "c" ++ y = "acb". */ void checkConstantEquivalenceClasses(); + /** check cardinality + * + * This function checks whether a cardinality inference needs to be applied + * to a set of equivalence classes. For details, see Step 5 of the proof + * procedure from Liang et al, CAV 2014. + */ + void checkCardinality(); //-----------------------end inference steps //-----------------------query functions @@ -182,6 +189,8 @@ class BaseSolver std::vector d_stringsEqc; /** A term index for each function kind */ std::map d_termIndex; + /** the cardinality of the alphabet */ + uint32_t d_cardSize; }; /* class BaseSolver */ } // namespace strings diff --git a/src/theory/strings/regexp_operation.cpp b/src/theory/strings/regexp_operation.cpp index 8707593c7..1b2de0eb5 100644 --- a/src/theory/strings/regexp_operation.cpp +++ b/src/theory/strings/regexp_operation.cpp @@ -42,7 +42,7 @@ RegExpOpr::RegExpOpr() std::vector{})), d_sigma_star(NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, d_sigma)) { - d_lastchar = TheoryStringsRewriter::getAlphabetCardinality()-1; + d_lastchar = utils::getAlphabetCardinality() - 1; } RegExpOpr::~RegExpOpr() {} diff --git a/src/theory/strings/regexp_operation.h b/src/theory/strings/regexp_operation.h index c7464760d..91d5df744 100644 --- a/src/theory/strings/regexp_operation.h +++ b/src/theory/strings/regexp_operation.h @@ -59,7 +59,7 @@ class RegExpOpr { private: /** the code point of the last character in the alphabet we are using */ - unsigned d_lastchar; + uint32_t d_lastchar; Node d_emptyString; Node d_true; Node d_false; diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 8a3c32fd8..5be2f96d4 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -127,7 +127,7 @@ TheoryStrings::TheoryStrings(context::Context* c, d_true = NodeManager::currentNM()->mkConst( true ); d_false = NodeManager::currentNM()->mkConst( false ); - d_card_size = TheoryStringsRewriter::getAlphabetCardinality(); + d_cardSize = utils::getAlphabetCardinality(); } TheoryStrings::~TheoryStrings() { @@ -565,7 +565,7 @@ void TheoryStrings::preRegisterTerm(TNode n) { std::vector vec = n.getConst().getVec(); for (unsigned u : vec) { - if (u >= d_card_size) + if (u >= d_cardSize) { std::stringstream ss; ss << "Characters in string \"" << n @@ -1146,110 +1146,6 @@ void TheoryStrings::checkRegisterTermsNormalForms() } } -void TheoryStrings::checkCardinality() { - //int cardinality = options::stringCharCardinality(); - //Trace("strings-solve-debug2") << "get cardinality: " << cardinality << endl; - - //AJR: this will create a partition of eqc, where each collection has length that are pairwise propagated to be equal. - // we do not require disequalities between the lengths of each collection, since we split on disequalities between lengths of string terms that are disequal (DEQ-LENGTH-SP). - // TODO: revisit this? - const std::vector& seqc = d_bsolver.getStringEqc(); - std::vector< std::vector< Node > > cols; - std::vector< Node > lts; - d_state.separateByLength(seqc, cols, lts); - - Trace("strings-card") << "Check cardinality...." << std::endl; - for( unsigned i = 0; i 1 ) { - // size > c^k - unsigned card_need = 1; - double curr = (double)cols[i].size(); - while( curr>d_card_size ){ - curr = curr/(double)d_card_size; - card_need++; - } - Trace("strings-card") << "Need length " << card_need << " for this number of strings (where alphabet size is " << d_card_size << ")." << std::endl; - //check if we need to split - bool needsSplit = true; - if( lr.isConst() ){ - // if constant, compare - Node cmp = NodeManager::currentNM()->mkNode( kind::GEQ, lr, NodeManager::currentNM()->mkConst( Rational( card_need ) ) ); - cmp = Rewriter::rewrite( cmp ); - needsSplit = cmp!=d_true; - }else{ - // find the minimimum constant that we are unknown to be disequal from, or otherwise stop if we increment such that cardinality does not apply - unsigned r=0; - bool success = true; - while( rmkConst( Rational(r) ); - if (d_state.areDisequal(rr, lr)) - { - r++; - } - else - { - success = false; - } - } - if( r>0 ){ - Trace("strings-card") << "Symbolic length " << lr << " must be at least " << r << " due to constant disequalities." << std::endl; - } - needsSplit = r::iterator itr1 = cols[i].begin(); - itr1 != cols[i].end(); ++itr1) { - for( std::vector< Node >::iterator itr2 = itr1 + 1; - itr2 != cols[i].end(); ++itr2) { - if (!d_state.areDisequal(*itr1, *itr2)) - { - // add split lemma - if (d_im.sendSplit(*itr1, *itr2, "CARD-SP")) - { - return; - } - } - } - } - EqcInfo* ei = d_state.getOrMakeEqcInfo(lr, true); - Trace("strings-card") - << "Previous cardinality used for " << lr << " is " - << ((int)ei->d_cardinalityLemK.get() - 1) << std::endl; - if (int_k + 1 > ei->d_cardinalityLemK.get()) - { - Node k_node = NodeManager::currentNM()->mkConst( ::CVC4::Rational( int_k ) ); - //add cardinality lemma - Node dist = NodeManager::currentNM()->mkNode( kind::DISTINCT, cols[i] ); - std::vector< Node > vec_node; - vec_node.push_back( dist ); - for( std::vector< Node >::iterator itr1 = cols[i].begin(); - itr1 != cols[i].end(); ++itr1) { - Node len = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, *itr1 ); - if( len!=lr ) { - Node len_eq_lr = len.eqNode(lr); - vec_node.push_back( len_eq_lr ); - } - } - Node len = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, cols[i][0] ); - Node cons = NodeManager::currentNM()->mkNode( kind::GEQ, len, k_node ); - cons = Rewriter::rewrite( cons ); - ei->d_cardinalityLemK.set(int_k + 1); - if( cons!=d_true ){ - d_im.sendInference( - d_empty_vec, vec_node, cons, "CARDINALITY", true); - return; - } - } - } - } - } - Trace("strings-card") << "...end check cardinality" << std::endl; -} - Node TheoryStrings::ppRewrite(TNode atom) { Trace("strings-ppr") << "TheoryStrings::ppRewrite " << atom << std::endl; Node atomElim; @@ -1328,7 +1224,7 @@ void TheoryStrings::runInferStep(InferStep s, int effort) case CHECK_REGISTER_TERMS_NF: checkRegisterTermsNormalForms(); break; case CHECK_EXTF_REDUCTION: d_esolver->checkExtfReductions(effort); break; case CHECK_MEMBERSHIP: checkMemberships(); break; - case CHECK_CARDINALITY: checkCardinality(); break; + case CHECK_CARDINALITY: d_bsolver.checkCardinality(); break; default: Unreachable(); break; } Trace("strings-process") << "Done " << s diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index 55852490f..f40af6e67 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -206,7 +206,7 @@ class TheoryStrings : public Theory { Node d_one; Node d_neg_one; /** the cardinality of the alphabet */ - unsigned d_card_size; + uint32_t d_cardSize; /** The notify class */ NotifyClass d_notify; /** Equaltity engine */ @@ -401,13 +401,6 @@ private: * FroCoS 2015. */ void checkMemberships(); - /** check cardinality - * - * This function checks whether a cardinality inference needs to be applied - * to a set of equivalence classes. For details, see Step 5 of the proof - * procedure from Liang et al, CAV 2014. - */ - void checkCardinality(); //-----------------------end inference steps //-----------------------representation of the strategy diff --git a/src/theory/strings/theory_strings_rewriter.cpp b/src/theory/strings/theory_strings_rewriter.cpp index e9a4ebfd1..339d11dd2 100644 --- a/src/theory/strings/theory_strings_rewriter.cpp +++ b/src/theory/strings/theory_strings_rewriter.cpp @@ -230,17 +230,6 @@ Node TheoryStringsRewriter::simpleRegexpConsume( std::vector< Node >& mchildren, return Node::null(); } -unsigned TheoryStringsRewriter::getAlphabetCardinality() -{ - if (options::stdPrintASCII()) - { - Assert(128 <= String::num_codes()); - return 128; - } - Assert(256 <= String::num_codes()); - return 256; -} - Node TheoryStringsRewriter::rewriteEquality(Node node) { Assert(node.getKind() == kind::EQUAL); diff --git a/src/theory/strings/theory_strings_rewriter.h b/src/theory/strings/theory_strings_rewriter.h index 7d76234bc..c9733433c 100644 --- a/src/theory/strings/theory_strings_rewriter.h +++ b/src/theory/strings/theory_strings_rewriter.h @@ -159,8 +159,6 @@ class TheoryStringsRewriter : public TheoryRewriter RewriteResponse postRewrite(TNode node) override; RewriteResponse preRewrite(TNode node) override; - /** get the cardinality of the alphabet used, based on the options */ - static unsigned getAlphabetCardinality(); /** rewrite equality * * This method returns a formula that is equivalent to the equality between diff --git a/src/theory/strings/theory_strings_utils.cpp b/src/theory/strings/theory_strings_utils.cpp index a564c82e1..a325108e4 100644 --- a/src/theory/strings/theory_strings_utils.cpp +++ b/src/theory/strings/theory_strings_utils.cpp @@ -16,6 +16,7 @@ #include "theory/strings/theory_strings_utils.h" +#include "options/strings_options.h" #include "theory/rewriter.h" using namespace CVC4::kind; @@ -25,6 +26,17 @@ namespace theory { namespace strings { namespace utils { +uint32_t getAlphabetCardinality() +{ + if (options::stdPrintASCII()) + { + Assert(128 <= String::num_codes()); + return 128; + } + Assert(256 <= String::num_codes()); + return 256; +} + Node mkAnd(const std::vector& a) { std::vector au; diff --git a/src/theory/strings/theory_strings_utils.h b/src/theory/strings/theory_strings_utils.h index ccdac8edf..51fe8cfc7 100644 --- a/src/theory/strings/theory_strings_utils.h +++ b/src/theory/strings/theory_strings_utils.h @@ -28,6 +28,9 @@ namespace theory { namespace strings { namespace utils { +/** get the cardinality of the alphabet used, based on the options */ +uint32_t getAlphabetCardinality(); + /** * Make the conjunction of nodes in a. Removes duplicate conjuncts, returns * true if a is empty, and a single literal if a has size 1. diff --git a/src/theory/strings/type_enumerator.h b/src/theory/strings/type_enumerator.h index 4218d4ce5..0171effaf 100644 --- a/src/theory/strings/type_enumerator.h +++ b/src/theory/strings/type_enumerator.h @@ -24,6 +24,7 @@ #include "expr/kind.h" #include "expr/type_node.h" #include "theory/strings/theory_strings_rewriter.h" +#include "theory/strings/theory_strings_utils.h" #include "theory/type_enumerator.h" #include "util/regexp.h" @@ -33,7 +34,7 @@ namespace strings { class StringEnumerator : public TypeEnumeratorBase { std::vector< unsigned > d_data; - unsigned d_cardinality; + uint32_t d_cardinality; Node d_curr; void mkCurr() { //make constant from d_data @@ -46,7 +47,7 @@ class StringEnumerator : public TypeEnumeratorBase { { Assert(type.getKind() == kind::TYPE_CONSTANT && type.getConst() == STRING_TYPE); - d_cardinality = TheoryStringsRewriter::getAlphabetCardinality(); + d_cardinality = utils::getAlphabetCardinality(); mkCurr(); } Node operator*() override { return d_curr; } @@ -85,7 +86,7 @@ class StringEnumerator : public TypeEnumeratorBase { class StringEnumeratorLength { private: - unsigned d_cardinality; + uint32_t d_cardinality; std::vector< unsigned > d_data; Node d_curr; void mkCurr() { @@ -94,7 +95,9 @@ class StringEnumeratorLength { } public: - StringEnumeratorLength(unsigned length, unsigned card = 256) : d_cardinality(card) { + StringEnumeratorLength(uint32_t length, uint32_t card = 256) + : d_cardinality(card) + { for( unsigned i=0; i