From d6fd1eff6025be9f0d8d5e7dcb02bffdda931828 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 11 Nov 2021 10:24:38 -0600 Subject: [PATCH] Fixes for update/nth over constant sequences (#7631) This includes fixes related to how we process seq.update / seq.nth applied to sequence constants. --- src/theory/strings/array_solver.cpp | 94 +++++++++++++++++++++++------ 1 file changed, 76 insertions(+), 18 deletions(-) diff --git a/src/theory/strings/array_solver.cpp b/src/theory/strings/array_solver.cpp index 09e3aefdd..b3e2aa457 100644 --- a/src/theory/strings/array_solver.cpp +++ b/src/theory/strings/array_solver.cpp @@ -15,8 +15,10 @@ #include "theory/strings/array_solver.h" +#include "expr/sequence.h" #include "theory/strings/arith_entail.h" #include "theory/strings/theory_strings_utils.h" +#include "theory/strings/word.h" #include "util/rational.h" using namespace cvc5::context; @@ -81,6 +83,7 @@ void ArraySolver::checkTerms(Kind k) Node r = d_state.getRepresentative(t[0]); NormalForm& nf = d_csolver.getNormalForm(r); Trace("seq-array-debug") << "...normal form " << nf.d_nf << std::endl; + std::vector nfChildren; if (nf.d_nf.empty()) { // updates should have been reduced (UPD_EMPTYSTR) @@ -92,8 +95,16 @@ void ArraySolver::checkTerms(Kind k) { Trace("seq-array-debug") << "...norm form size 1" << std::endl; // NOTE: could split on n=0 if needed, do not introduce ITE - if (nf.d_nf[0].getKind() == SEQ_UNIT) + Kind ck = nf.d_nf[0].getKind(); + // Note that (seq.unit c) is rewritten to CONST_SEQUENCE{c}, hence we + // check two cases here. It is important for completeness of this schema + // to handle this differently from STRINGS_ARRAY_UPDATE_CONCAT / + // STRINGS_ARRAY_NTH_CONCAT. Otherwise we would conclude a trivial + // equality when update/nth is applied to a constant of length one. + if (ck == SEQ_UNIT + || (ck == CONST_SEQUENCE && Word::getLength(nf.d_nf[0]) == 1)) { + Trace("seq-array-debug") << "...unit case" << std::endl; // do we know whether n = 0 ? // x = (seq.unit m) => (seq.update x n z) = ite(n=0, z, (seq.unit m)) // x = (seq.unit m) => (seq.nth x n) = ite(n=0, m, Uf(x, n)) @@ -109,7 +120,15 @@ void ArraySolver::checkTerms(Kind k) else { Assert(k == SEQ_NTH); - thenBranch = nf.d_nf[0][0]; + if (ck == CONST_SEQUENCE) + { + const Sequence& seq = nf.d_nf[0].getConst(); + thenBranch = seq.getVec()[0]; + } + else + { + thenBranch = nf.d_nf[0][0]; + } Node uf = SkolemCache::mkSkolemSeqNth(t[0].getType(), "Uf"); elseBranch = nm->mkNode(APPLY_UF, uf, t[0], t[1]); iid = InferenceId::STRINGS_ARRAY_NTH_UNIT; @@ -126,17 +145,33 @@ void ArraySolver::checkTerms(Kind k) d_eqProc.insert(eq); d_im.sendInference(exp, eq, iid); } + continue; } - // otherwise, the equivalence class is pure wrt concatenation - d_currTerms[k].push_back(t); - continue; + else if (ck != CONST_SEQUENCE) + { + // otherwise, if the normal form is not a constant sequence, the + // equivalence class is pure wrt concatenation. + d_currTerms[k].push_back(t); + continue; + } + // if the normal form is a constant sequence, it is treated as a + // concatenation. We split per character and case split on whether the + // nth/update falls on each character below, which must have a size + // greater than one. + std::vector chars = Word::getChars(nf.d_nf[0]); + Assert (chars.size()>1); + nfChildren.insert(nfChildren.end(), chars.begin(), chars.end()); + } + else + { + nfChildren.insert(nfChildren.end(), nf.d_nf.begin(), nf.d_nf.end()); } // otherwise, we are the concatenation of the components // NOTE: for nth, split on index vs component lengths, do not introduce ITE std::vector cond; std::vector cchildren; std::vector lacc; - for (const Node& c : nf.d_nf) + for (const Node& c : nfChildren) { Trace("seq-array-debug") << "...process " << c << std::endl; Node clen = nm->mkNode(STRING_LENGTH, c); @@ -146,26 +181,49 @@ void ArraySolver::checkTerms(Kind k) Node currSum = lacc.size() == 1 ? lacc[0] : nm->mkNode(PLUS, lacc); currIndex = nm->mkNode(MINUS, currIndex, currSum); } - if (k == STRING_UPDATE) + Node cc; + // If it is a constant of length one, then the update/nth is determined + // in this interval. Notice this is done here as + // an optimization to short cut introducing terms like + // (seq.nth (seq.unit c) i), which by construction is only relevant in + // the context where i = 0, hence we replace by c here. + if (c.getKind() == CONST_SEQUENCE) { - Node cc = nm->mkNode(STRING_UPDATE, c, currIndex, t[2]); - Trace("seq-array-debug") << "......component " << cc << std::endl; - cchildren.push_back(cc); + const Sequence& seq = c.getConst(); + if (seq.size() == 1) + { + if (k == STRING_UPDATE) + { + cc = nm->mkNode(ITE, t[1].eqNode(d_zero), t[2], c); + } + else + { + cc = seq.getVec()[0]; + } + } } - else + // if we did not process as a constant of length one + if (cc.isNull()) { - Assert(k == SEQ_NTH); - Node cc = nm->mkNode(SEQ_NTH, c, currIndex); - Trace("seq-array-debug") << "......component " << cc << std::endl; - cchildren.push_back(cc); + if (k == STRING_UPDATE) + { + cc = nm->mkNode(STRING_UPDATE, c, currIndex, t[2]); + } + else + { + Assert(k == SEQ_NTH); + cc = nm->mkNode(SEQ_NTH, c, currIndex); + } } + Trace("seq-array-debug") << "......component " << cc << std::endl; + cchildren.push_back(cc); lacc.push_back(clen); if (k == SEQ_NTH) { Node currSumPost = lacc.size() == 1 ? lacc[0] : nm->mkNode(PLUS, lacc); - Node cc = nm->mkNode(LT, t[1], currSumPost); - Trace("seq-array-debug") << "......condition " << cc << std::endl; - cond.push_back(cc); + Node cf = nm->mkNode(LT, t[1], currSumPost); + Trace("seq-array-debug") << "......condition " << cf << std::endl; + cond.push_back(cf); } } // z = (seq.++ x y) => -- 2.30.2