From f2ef41253381cc4a5cd7ba1d62a9443535f8a3dd Mon Sep 17 00:00:00 2001 From: Andres Noetzli Date: Wed, 19 Jan 2022 15:45:21 -0800 Subject: [PATCH] Add rewrites for `seq.update`/`seq.nth` (#7966) This commit adds rewrites for `seq.update` and `seq.nth`. It adds a rewrite to check if an update is out of bounds and it adds rewrites to evaluate these operators using `stripSymbolicLength()`. This allows us for example to evaluate updates on concatenations of non-constant `seq.unit`s. --- src/theory/strings/arith_entail.cpp | 7 +- src/theory/strings/rewrites.cpp | 4 +- src/theory/strings/rewrites.h | 4 +- src/theory/strings/sequences_rewriter.cpp | 44 ++++- test/unit/theory/sequences_rewriter_white.cpp | 157 ++++++++++++++++++ 5 files changed, 209 insertions(+), 7 deletions(-) diff --git a/src/theory/strings/arith_entail.cpp b/src/theory/strings/arith_entail.cpp index 19f1bc97d..0659a2ff9 100644 --- a/src/theory/strings/arith_entail.cpp +++ b/src/theory/strings/arith_entail.cpp @@ -861,12 +861,17 @@ Node ArithEntail::getConstantBoundLength(TNode s, bool isLower) const return ret; } NodeManager* nm = NodeManager::currentNM(); + Kind sk = s.getKind(); if (s.isConst()) { size_t len = Word::getLength(s); ret = nm->mkConstInt(Rational(len)); } - else if (s.getKind() == STRING_CONCAT) + else if (sk == SEQ_UNIT) + { + ret = nm->mkConstInt(1); + } + else if (sk == STRING_CONCAT) { Rational sum(0); bool success = true; diff --git a/src/theory/strings/rewrites.cpp b/src/theory/strings/rewrites.cpp index 262459b4f..98420faaa 100644 --- a/src/theory/strings/rewrites.cpp +++ b/src/theory/strings/rewrites.cpp @@ -151,11 +151,13 @@ const char* toString(Rewrite r) case Rewrite::SS_STRIP_END_PT: return "SS_STRIP_END_PT"; case Rewrite::SS_STRIP_START_PT: return "SS_STRIP_START_PT"; case Rewrite::UPD_EVAL: return "UPD_EVAL"; + case Rewrite::UPD_EVAL_SYM: return "UPD_EVAL_SYM"; case Rewrite::UPD_EMPTYSTR: return "UPD_EMPTYSTR"; case Rewrite::UPD_CONST_INDEX_MAX_OOB: return 
"UPD_CONST_INDEX_MAX_OOB"; case Rewrite::UPD_CONST_INDEX_NEG: return "UPD_CONST_INDEX_NEG"; case Rewrite::UPD_CONST_INDEX_OOB: return "UPD_CONST_INDEX_OOB"; case Rewrite::UPD_REV: return "UPD_REV"; + case Rewrite::UPD_OOB: return "UPD_OOB"; case Rewrite::STOI_CONCAT_NONNUM: return "STOI_CONCAT_NONNUM"; case Rewrite::STOI_EVAL: return "STOI_EVAL"; case Rewrite::STR_CONV_CONST: return "STR_CONV_CONST"; @@ -225,7 +227,7 @@ const char* toString(Rewrite r) case Rewrite::SEQ_UNIT_EVAL: return "SEQ_UNIT_EVAL"; case Rewrite::SEQ_NTH_EVAL: return "SEQ_NTH_EVAL"; case Rewrite::SEQ_NTH_TOTAL_OOB: return "SEQ_NTH_TOTAL_OOB"; - case Rewrite::SEQ_NTH_UNIT: return "SEQ_NTH_UNIT"; + case Rewrite::SEQ_NTH_EVAL_SYM: return "SEQ_NTH_EVAL_SYM"; default: return "?"; } } diff --git a/src/theory/strings/rewrites.h b/src/theory/strings/rewrites.h index eedd0fb40..de8af0b4b 100644 --- a/src/theory/strings/rewrites.h +++ b/src/theory/strings/rewrites.h @@ -152,11 +152,13 @@ enum class Rewrite : uint32_t SS_STRIP_END_PT, SS_STRIP_START_PT, UPD_EVAL, + UPD_EVAL_SYM, UPD_EMPTYSTR, UPD_CONST_INDEX_MAX_OOB, UPD_CONST_INDEX_NEG, UPD_CONST_INDEX_OOB, UPD_REV, + UPD_OOB, STOI_CONCAT_NONNUM, STOI_EVAL, STR_CONV_CONST, @@ -226,7 +228,7 @@ enum class Rewrite : uint32_t SEQ_UNIT_EVAL, SEQ_NTH_EVAL, SEQ_NTH_TOTAL_OOB, - SEQ_NTH_UNIT + SEQ_NTH_EVAL_SYM }; /** diff --git a/src/theory/strings/sequences_rewriter.cpp b/src/theory/strings/sequences_rewriter.cpp index 374b43c34..e4db2f325 100644 --- a/src/theory/strings/sequences_rewriter.cpp +++ b/src/theory/strings/sequences_rewriter.cpp @@ -1765,10 +1765,18 @@ Node SequencesRewriter::rewriteSeqNth(Node node) } } - if (s.getKind() == SEQ_UNIT && i.isConst() && i.getConst().isZero()) + std::vector prefix, suffix; + utils::getConcat(s, suffix); + if ((i.isConst() && i.getConst().isZero()) + || d_stringsEntail.stripSymbolicLength(suffix, prefix, 1, i, true)) { - Node ret = s[0]; - return returnRewrite(node, ret, Rewrite::SEQ_NTH_UNIT); + if (suffix.size() > 0 
&& suffix[0].getKind() == SEQ_UNIT) + { + // (seq.nth (seq.++ prefix (seq.unit x) suffix) n) ---> x + // if len(prefix) = n + Node ret = suffix[0][0]; + return returnRewrite(node, ret, Rewrite::SEQ_NTH_EVAL_SYM); + } } return node; @@ -2088,11 +2096,39 @@ Node SequencesRewriter::rewriteUpdate(Node node) } } + NodeManager* nm = NodeManager::currentNM(); + Node zero = nm->mkConstInt(0); + Node sLen = nm->mkNode(STRING_LENGTH, s); + if (d_arithEntail.check(zero, i, true) || d_arithEntail.check(i, sLen)) + { + // (seq.update s i x) ---> s if i < 0 or i >= len(s) + Node ret = s; + return returnRewrite(node, ret, Rewrite::UPD_OOB); + } + + std::vector prefix, suffix; + utils::getConcat(s, suffix); + if ((i.isConst() && i.getConst().isZero()) + || d_stringsEntail.stripSymbolicLength(suffix, prefix, 1, i, true)) + { + Node updateLen = nm->mkNode(STRING_LENGTH, x); + std::vector replaced; + if (d_stringsEntail.stripSymbolicLength( + suffix, replaced, 1, updateLen, true)) + { + // (seq.update (seq.++ p r s) i x) ---> (seq.++ p x s) + // if len(p) = i and len(r) = len(x) + prefix.emplace_back(x); + prefix.insert(prefix.end(), suffix.begin(), suffix.end()); + Node ret = utils::mkConcat(prefix, node.getType()); + return returnRewrite(node, ret, Rewrite::UPD_EVAL_SYM); + } + } + if (s.getKind() == STRING_REV && d_stringsEntail.checkLengthOne(x)) { // str.update(str.rev(s), n, t) ---> // str.rev(str.update(s, len(s) - (n + 1), t)) - NodeManager* nm = NodeManager::currentNM(); Node idx = nm->mkNode(MINUS, nm->mkNode(STRING_LENGTH, s), nm->mkNode(PLUS, i, nm->mkConstInt(Rational(1)))); diff --git a/test/unit/theory/sequences_rewriter_white.cpp b/test/unit/theory/sequences_rewriter_white.cpp index c468aa463..8ce982291 100644 --- a/test/unit/theory/sequences_rewriter_white.cpp +++ b/test/unit/theory/sequences_rewriter_white.cpp @@ -216,6 +216,58 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) false)); } +TEST_F(TestTheoryWhiteSequencesRewriter,
rewrite_nth) +{ + TypeNode intType = d_nodeManager->integerType(); + + Node x = d_nodeManager->mkVar("x", intType); + Node y = d_nodeManager->mkVar("y", intType); + Node z = d_nodeManager->mkVar("z", intType); + Node w = d_nodeManager->mkVar("w", intType); + Node v = d_nodeManager->mkVar("v", intType); + + Node zero = d_nodeManager->mkConstInt(0); + Node one = d_nodeManager->mkConstInt(1); + + Node sx = d_nodeManager->mkNode(SEQ_UNIT, x); + Node sy = d_nodeManager->mkNode(SEQ_UNIT, y); + Node sz = d_nodeManager->mkNode(SEQ_UNIT, z); + Node sw = d_nodeManager->mkNode(SEQ_UNIT, w); + Node sv = d_nodeManager->mkNode(SEQ_UNIT, v); + Node xyz = d_nodeManager->mkNode(STRING_CONCAT, sx, sy, sz); + Node wv = d_nodeManager->mkNode(STRING_CONCAT, sw, sv); + + { + // Same normal form for: + // + // (seq.nth (seq.unit x) 0) + // + // x + Node n = d_nodeManager->mkNode(SEQ_NTH, sx, zero); + sameNormalForm(n, x); + } + + { + // Same normal form for: + // + // (seq.nth (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) 0) + // + // x + Node n = d_nodeManager->mkNode(SEQ_NTH, xyz, zero); + sameNormalForm(n, x); + } + + { + // Same normal form for: + // + // (seq.nth (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) 1) + // + // y + Node n = d_nodeManager->mkNode(SEQ_NTH, xyz, one); + sameNormalForm(n, y); + } +} + TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr) { StringsRewriter sr(d_rewriter, nullptr); @@ -356,6 +408,111 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr) sameNormalForm(substr, empty); } +TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_update) +{ + TypeNode intType = d_nodeManager->integerType(); + + Node x = d_nodeManager->mkVar("x", intType); + Node y = d_nodeManager->mkVar("y", intType); + Node z = d_nodeManager->mkVar("z", intType); + Node w = d_nodeManager->mkVar("w", intType); + Node v = d_nodeManager->mkVar("v", intType); + + Node negOne = d_nodeManager->mkConstInt(-1); + Node zero = d_nodeManager->mkConstInt(0); + Node one =
d_nodeManager->mkConstInt(1); + Node three = d_nodeManager->mkConstInt(3); + + Node sx = d_nodeManager->mkNode(SEQ_UNIT, x); + Node sy = d_nodeManager->mkNode(SEQ_UNIT, y); + Node sz = d_nodeManager->mkNode(SEQ_UNIT, z); + Node sw = d_nodeManager->mkNode(SEQ_UNIT, w); + Node sv = d_nodeManager->mkNode(SEQ_UNIT, v); + Node xyz = d_nodeManager->mkNode(STRING_CONCAT, sx, sy, sz); + Node wv = d_nodeManager->mkNode(STRING_CONCAT, sw, sv); + + { + // Same normal form for: + // + // (seq.update + // (seq.unit x) + // 0 + // (seq.unit w)) + // + // (seq.unit w) + Node n = d_nodeManager->mkNode(STRING_UPDATE, sx, zero, sw); + sameNormalForm(n, sw); + } + + { + // Same normal form for: + // + // (seq.update + // (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) + // 0 + // (seq.unit w)) + // + // (seq.++ (seq.unit w) (seq.unit y) (seq.unit z)) + Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, zero, sw); + Node wyz = d_nodeManager->mkNode(STRING_CONCAT, sw, sy, sz); + sameNormalForm(n, wyz); + } + + { + // Same normal form for: + // + // (seq.update + // (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) + // 1 + // (seq.unit w)) + // + // (seq.++ (seq.unit x) (seq.unit w) (seq.unit z)) + Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, one, sw); + Node xwz = d_nodeManager->mkNode(STRING_CONCAT, sx, sw, sz); + sameNormalForm(n, xwz); + } + + { + // Same normal form for: + // + // (seq.update + // (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) + // 1 + // (seq.++ (seq.unit w) (seq.unit v))) + // + // (seq.++ (seq.unit x) (seq.unit w) (seq.unit v)) + Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, one, wv); + Node xwv = d_nodeManager->mkNode(STRING_CONCAT, sx, sw, sv); + sameNormalForm(n, xwv); + } + + { + // Same normal form for: + // + // (seq.update + // (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) + // -1 + // (seq.++ (seq.unit w) (seq.unit v))) + // + // (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) + Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, negOne,
wv); + sameNormalForm(n, xyz); + } + + { + // Same normal form for: + // + // (seq.update + // (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) + // 3 + // (seq.unit w)) + // + // (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) + Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, three, sw); + sameNormalForm(n, xyz); + } +} + TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_concat) { TypeNode intType = d_nodeManager->integerType(); -- 2.30.2