From effb0d47ba5bfaebae17dcd06153489dccd90eff Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 14 Dec 2021 13:35:09 -0600 Subject: [PATCH] Connecting the core array solver in strings (#7800) This PR takes most of the remaining changes from the seqArray branch apart from the extension to model construction. Notably it connects the core array solver to the array solver in strings. --- src/expr/skolem_manager.cpp | 1 + src/expr/skolem_manager.h | 2 ++ src/smt/proof_post_processor.cpp | 1 + src/theory/strings/array_core_solver.cpp | 8 ++--- src/theory/strings/array_core_solver.h | 25 +++++++++++++-- src/theory/strings/array_solver.cpp | 37 +++++++++++++++++++++++ src/theory/strings/array_solver.h | 29 ++++++++++++++++++ src/theory/strings/rewrites.cpp | 2 ++ src/theory/strings/rewrites.h | 4 ++- src/theory/strings/sequences_rewriter.cpp | 24 +++++++++++---- 10 files changed, 119 insertions(+), 14 deletions(-) diff --git a/src/expr/skolem_manager.cpp b/src/expr/skolem_manager.cpp index 206ebb9ce..f08ffc5f4 100644 --- a/src/expr/skolem_manager.cpp +++ b/src/expr/skolem_manager.cpp @@ -67,6 +67,7 @@ const char* toString(SkolemFunId id) case SkolemFunId::SK_FIRST_MATCH: return "SK_FIRST_MATCH"; case SkolemFunId::SK_FIRST_MATCH_POST: return "SK_FIRST_MATCH_POST"; case SkolemFunId::RE_UNFOLD_POS_COMPONENT: return "RE_UNFOLD_POS_COMPONENT"; + case SkolemFunId::SEQ_MODEL_BASE_ELEMENT: return "SEQ_MODEL_BASE_ELEMENT"; case SkolemFunId::BAGS_CARD_CARDINALITY: return "BAGS_CARD_CARDINALITY"; case SkolemFunId::BAGS_CARD_ELEMENTS: return "BAGS_CARD_ELEMENTS"; case SkolemFunId::BAGS_CARD_N: return "BAGS_CARD_N"; diff --git a/src/expr/skolem_manager.h b/src/expr/skolem_manager.h index 93b26b6cb..cca28ccf0 100644 --- a/src/expr/skolem_manager.h +++ b/src/expr/skolem_manager.h @@ -112,6 +112,8 @@ enum class SkolemFunId * i = 0, ..., n. */ RE_UNFOLD_POS_COMPONENT, + /** Sequence model construction, element for base */ + SEQ_MODEL_BASE_ELEMENT, BAGS_CARD_CARDINALITY, BAGS_CARD_ELEMENTS, BAGS_CARD_N, diff --git a/src/smt/proof_post_processor.cpp b/src/smt/proof_post_processor.cpp index 90f0a48bf..167a82e26 100644 --- a/src/smt/proof_post_processor.cpp +++ b/src/smt/proof_post_processor.cpp @@ -422,6 +422,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, // not eliminated return Node::null(); } + Trace("smt-proof-pp-debug") << "Expand macro " << id << std::endl; // macro elimination if (id == PfRule::MACRO_SR_EQ_INTRO) { diff --git a/src/theory/strings/array_core_solver.cpp b/src/theory/strings/array_core_solver.cpp index ed3690068..3b8fdeff4 100644 --- a/src/theory/strings/array_core_solver.cpp +++ b/src/theory/strings/array_core_solver.cpp @@ -70,7 +70,7 @@ void ArrayCoreSolver::checkNth(const std::vector& nthTerms) // (seq.extract A i l) ^ (<= 0 i) ^ (< i (str.len A)) --> (seq.unit // (seq.nth A i)) std::vector exp; - Node cond1 = nm->mkNode(LEQ, nm->mkConst(Rational(0)), n[1]); + Node cond1 = nm->mkNode(LEQ, nm->mkConstInt(Rational(0)), n[1]); Node cond2 = nm->mkNode(LT, n[1], nm->mkNode(STRING_LENGTH, n[0])); Node cond = nm->mkNode(AND, cond1, cond2); Node body1 = nm->mkNode( @@ -115,7 +115,7 @@ void ArrayCoreSolver::checkUpdate(const std::vector& updateTerms) // n[2][0] Node left = nm->mkNode(SEQ_NTH, termProxy, n[1]); Node right = - nm->mkNode(SEQ_NTH, n[2], nm->mkConst(Rational(0))); // n[2][0] + nm->mkNode(SEQ_NTH, n[2], nm->mkConstInt(Rational(0))); // n[2][0] right = Rewriter::rewrite(right); Node lem = nm->mkNode(EQUAL, left, right); Trace("seq-array-debug") << "enter" << std::endl; @@ -211,10 +211,10 @@ void ArrayCoreSolver::check(const std::vector& nthTerms, Node i = n[1]; Node sLen = nm->mkNode(STRING_LENGTH, s); Node iRev = nm->mkNode( - MINUS, sLen, nm->mkNode(PLUS, i, nm->mkConst(Rational(1)))); + MINUS, sLen, nm->mkNode(PLUS, i, nm->mkConstInt(Rational(1)))); std::vector nexp; - nexp.push_back(nm->mkNode(LEQ, nm->mkConst(Rational(0)), i)); + nexp.push_back(nm->mkNode(LEQ, nm->mkConstInt(Rational(0)), i)); nexp.push_back(nm->mkNode(LT, i, sLen)); // 0 <= i ^ i < len(s) => seq.nth(seq.rev(s), i) = seq.nth(s, len(s) - i - diff --git a/src/theory/strings/array_core_solver.h b/src/theory/strings/array_core_solver.h index 7101da625..3873f6a69 100644 --- a/src/theory/strings/array_core_solver.h +++ b/src/theory/strings/array_core_solver.h @@ -62,7 +62,7 @@ class ArrayCoreSolver : protected EnvObj const std::map& getWriteModel(Node eqc); /** - * Get connected sequences + * Get connected sequences, see documentation of computeConnected. * @return a map M such that sequence equivalence class representatives x and * y are connected if an only if M[x] = M[y]. */ @@ -93,7 +93,18 @@ class ArrayCoreSolver : protected EnvObj */ void checkUpdate(const std::vector& updateTerms); - // TODO: document + /** + * Given the current set of update terms, this computes the connected + * sequences implied by the current equality information + this set of terms. + * Connected sequences is a reflexive transitive relation where additionally + * a and b are connected if there exists an update term (seq.update a n x) + * that is currently equal to b. + * + * This method runs a union find algorithm to compute all connected sequences. + * + * As a result of running this method, the map d_connectedSeq is populated + * with information regarding which sequences are connected. + */ void computeConnected(const std::vector& updateTerms); /** The solver state object */ @@ -110,7 +121,15 @@ class ArrayCoreSolver : protected EnvObj ExtTheory& d_extt; /** The write model */ std::map> d_writeModel; - /** Connected */ + /** + * Map from sequences to their "connected representative". Two sequences are + * connected (based on the definition described in computeConnected) iff they + * have the same connected representative. Sequences that do not occur in + * this map are assumed to be their own connected representative. + * + * This map is only valid after running computeConnected, and is valid + * only during model building. + */ std::map d_connectedSeq; /** The set of lemmas been sent */ context::CDHashSet d_lem; diff --git a/src/theory/strings/array_solver.cpp b/src/theory/strings/array_solver.cpp index c04bfe918..672ca8b76 100644 --- a/src/theory/strings/array_solver.cpp +++ b/src/theory/strings/array_solver.cpp @@ -41,6 +41,7 @@ ArraySolver::ArraySolver(Env& env, d_termReg(tr), d_csolver(cs), d_esolver(es), + d_coreSolver(env, s, im, tr, cs, es, extt), d_eqProc(context()) { NodeManager* nm = NodeManager::currentNM(); @@ -63,6 +64,32 @@ void ArraySolver::checkArrayConcat() checkTerms(SEQ_NTH); } +void ArraySolver::checkArray() +{ + if (!d_termReg.hasSeqUpdate()) + { + Trace("seq-array") << "No seq.update/seq.nth terms, skipping check..." + << std::endl; + return; + } + Trace("seq-array") << "ArraySolver::checkArray..." << std::endl; + d_coreSolver.check(d_currTerms[SEQ_NTH], d_currTerms[STRING_UPDATE]); +} + +void ArraySolver::checkArrayEager() +{ + if (!d_termReg.hasSeqUpdate()) + { + Trace("seq-array") << "No seq.update/seq.nth terms, skipping check..." + << std::endl; + return; + } + Trace("seq-array") << "ArraySolver::checkArray..." << std::endl; + std::vector nthTerms = d_esolver.getActive(SEQ_NTH); + std::vector updateTerms = d_esolver.getActive(STRING_UPDATE); + d_coreSolver.check(nthTerms, updateTerms); +} + void ArraySolver::checkTerms(Kind k) { Assert(k == STRING_UPDATE || k == SEQ_NTH); @@ -271,6 +298,16 @@ void ArraySolver::checkTerms(Kind k) } } +const std::map& ArraySolver::getWriteModel(Node eqc) +{ + return d_coreSolver.getWriteModel(eqc); +} + +const std::map& ArraySolver::getConnectedSequences() +{ + return d_coreSolver.getConnectedSequences(); +} + } // namespace strings } // namespace theory } // namespace cvc5 diff --git a/src/theory/strings/array_solver.h b/src/theory/strings/array_solver.h index 941061e9e..23bacd118 100644 --- a/src/theory/strings/array_solver.h +++ b/src/theory/strings/array_solver.h @@ -19,6 +19,7 @@ #define CVC5__THEORY__STRINGS__ARRAY_SOLVER_H #include "context/cdhashset.h" +#include "theory/strings/array_core_solver.h" #include "theory/strings/core_solver.h" #include "theory/strings/extf_solver.h" #include "theory/strings/inference_manager.h" @@ -54,6 +55,32 @@ class ArraySolver : protected EnvObj * their application to concatenation terms. */ void checkArrayConcat(); + /** + * Perform reasoning about seq.nth and seq.update operations (lazily), which + * calls the core sequences-array solver for the set of nth/update terms over atomic + * equivalence classes. + */ + void checkArray(); + /** + * Same as `checkArray`, but called eagerly, and for all nth/update terms, not just + * those over atomic equivalence classes. + */ + void checkArrayEager(); + + /** + * @param eqc The sequence equivalence class representative. We can assume + * the equivalence class of eqc contains no concatenation terms. + * @return the map corresponding to the model for eqc. The domain of + * the returned map should be in distinct integer equivalence classes of the + * equality engine of strings theory. The model assigned to eqc will be + * a skeleton constructed via seq.++ where the components take values from + * this map. + */ + const std::map& getWriteModel(Node eqc); + /** + * Get connected sequences from the core array solver. + */ + const std::map& getConnectedSequences(); private: /** check terms of given kind */ @@ -72,6 +99,8 @@ class ArraySolver : protected EnvObj std::map > d_currTerms; /** Common constants */ Node d_zero; + /** The core array solver */ + ArrayCoreSolver d_coreSolver; /** Equalities we have processed in the current context */ NodeSet d_eqProc; }; diff --git a/src/theory/strings/rewrites.cpp b/src/theory/strings/rewrites.cpp index 4da6e5600..bfe9021aa 100644 --- a/src/theory/strings/rewrites.cpp +++ b/src/theory/strings/rewrites.cpp @@ -154,6 +154,7 @@ const char* toString(Rewrite r) case Rewrite::UPD_CONST_INDEX_MAX_OOB: return "UPD_CONST_INDEX_MAX_OOB"; case Rewrite::UPD_CONST_INDEX_NEG: return "UPD_CONST_INDEX_NEG"; case Rewrite::UPD_CONST_INDEX_OOB: return "UPD_CONST_INDEX_OOB"; + case Rewrite::UPD_REV: return "UPD_REV"; case Rewrite::STOI_CONCAT_NONNUM: return "STOI_CONCAT_NONNUM"; case Rewrite::STOI_EVAL: return "STOI_EVAL"; case Rewrite::STR_CONV_CONST: return "STR_CONV_CONST"; @@ -223,6 +224,7 @@ const char* toString(Rewrite r) case Rewrite::SEQ_UNIT_EVAL: return "SEQ_UNIT_EVAL"; case Rewrite::SEQ_NTH_EVAL: return "SEQ_NTH_EVAL"; case Rewrite::SEQ_NTH_TOTAL_OOB: return "SEQ_NTH_TOTAL_OOB"; + case Rewrite::SEQ_NTH_UNIT: return "SEQ_NTH_UNIT"; default: return "?"; } } diff --git a/src/theory/strings/rewrites.h b/src/theory/strings/rewrites.h index c96dffcde..b57c5f276 100644 --- a/src/theory/strings/rewrites.h +++ b/src/theory/strings/rewrites.h @@ -155,6 +155,7 @@ enum class Rewrite : uint32_t UPD_CONST_INDEX_MAX_OOB, UPD_CONST_INDEX_NEG, UPD_CONST_INDEX_OOB, + UPD_REV, STOI_CONCAT_NONNUM, STOI_EVAL, STR_CONV_CONST, @@ -223,7 +224,8 @@ enum class Rewrite : uint32_t CHARAT_ELIM, SEQ_UNIT_EVAL, SEQ_NTH_EVAL, - SEQ_NTH_TOTAL_OOB + SEQ_NTH_TOTAL_OOB, + SEQ_NTH_UNIT }; /** diff --git a/src/theory/strings/sequences_rewriter.cpp b/src/theory/strings/sequences_rewriter.cpp index 7670c0b70..1ccb67490 100644 --- a/src/theory/strings/sequences_rewriter.cpp +++ b/src/theory/strings/sequences_rewriter.cpp @@ -1759,15 +1759,15 @@ Node SequencesRewriter::rewriteSeqNth(Node node) Node ret = nm->mkGroundValue(s.getType().getSequenceElementType()); return returnRewrite(node, ret, Rewrite::SEQ_NTH_TOTAL_OOB); } - else - { - return node; - } } - else + + if (s.getKind() == SEQ_UNIT && i.isConst() && i.getConst().isZero()) { - return node; + Node ret = s[0]; + return returnRewrite(node, ret, Rewrite::SEQ_NTH_UNIT); } + + return node; } Node SequencesRewriter::rewriteCharAt(Node node) @@ -2045,6 +2045,8 @@ Node SequencesRewriter::rewriteUpdate(Node node) { Assert(node.getKind() == kind::STRING_UPDATE); Node s = node[0]; + Node i = node[1]; + Node x = node[2]; if (s.isConst()) { if (Word::isEmpty(s)) @@ -2082,6 +2084,16 @@ Node SequencesRewriter::rewriteUpdate(Node node) } } + if (s.getKind() == STRING_REV) + { + NodeManager* nm = NodeManager::currentNM(); + Node idx = nm->mkNode(MINUS, + nm->mkNode(STRING_LENGTH, s), + nm->mkNode(PLUS, i, nm->mkConst(Rational(1)))); + Node ret = nm->mkNode(STRING_REV, nm->mkNode(STRING_UPDATE, s, idx, x)); + return returnRewrite(node, ret, Rewrite::UPD_REV); + } + return node; } -- 2.30.2