#include "theory/datatypes/datatypes_rewriter.h"
#include "theory/quantifiers/term_util.h"
#include "theory/rewriter.h"
+#include "theory/strings/theory_strings_rewriter.h"
using namespace CVC4::kind;
using namespace std;
{
new_ret = extendedRewriteArith(ret);
}
+ else if (tid == THEORY_STRINGS)
+ {
+ new_ret = extendedRewriteStrings(ret);
+ }
}
//----------------------end theory-specific post-rewriting
return new_ret;
}
+Node ExtendedRewriter::extendedRewriteStrings(Node ret)
+{
+ Node new_ret;
+ Trace("q-ext-rewrite-debug")
+ << "Extended rewrite strings : " << ret << std::endl;
+ NodeManager* nm = NodeManager::currentNM();
+ if (ret.getKind() == EQUAL)
+ {
+ if (ret[0].getType().isString())
+ {
+ Node tcontains[2];
+ bool tcontainsOneTrue = false;
+ unsigned tcontainsTrueIndex = 0;
+ for (unsigned i = 0; i < 2; i++)
+ {
+ Node tc = nm->mkNode(STRING_STRCTN, ret[i], ret[1 - i]);
+ tcontains[i] = Rewriter::rewrite(tc);
+ if (tcontains[i].isConst())
+ {
+ if (tcontains[i].getConst<bool>())
+ {
+ tcontainsOneTrue = true;
+ tcontainsTrueIndex = i;
+ }
+ else
+ {
+ new_ret = tcontains[i];
+ // if str.contains( x, y ) ---> false then x = y ---> false
+ // Notice we may not catch this in the rewriter for strings
+ // equality, since it only calls the specific rewriter for
+ // contains and not the full rewriter.
+ debugExtendedRewrite(ret, new_ret, "eq-contains-one-false");
+ return new_ret;
+ }
+ }
+ }
+ if (tcontainsOneTrue)
+ {
+ // if str.contains( x, y ) ---> true
+ // then x = y ---> contains( y, x )
+ new_ret = tcontains[1 - tcontainsTrueIndex];
+ debugExtendedRewrite(ret, new_ret, "eq-contains-one-true");
+ return new_ret;
+ }
+ else if (tcontains[0] == tcontains[1] && tcontains[0] != ret)
+ {
+ // if str.contains( x, y ) ---> t and str.contains( y, x ) ---> t,
+ // then x = y ---> t
+ new_ret = tcontains[0];
+ debugExtendedRewrite(ret, new_ret, "eq-dual-contains-eq");
+ return new_ret;
+ }
+
+ std::vector<Node> c[2];
+ for (unsigned i = 0; i < 2; i++)
+ {
+ strings::TheoryStringsRewriter::getConcat(ret[i], c[i]);
+ }
+
+ bool changed = false;
+ for (unsigned i = 0; i < 2; i++)
+ {
+ while (!c[0].empty() && !c[1].empty() && c[0].back() == c[1].back())
+ {
+ c[0].pop_back();
+ c[1].pop_back();
+ changed = true;
+ }
+ // splice constants
+ if (!c[0].empty() && !c[1].empty() && c[0].back().isConst()
+ && c[1].back().isConst())
+ {
+ String cs[2];
+ for (unsigned j = 0; j < 2; j++)
+ {
+ cs[j] = c[j].back().getConst<String>();
+ }
+ unsigned larger = cs[0].size() > cs[1].size() ? 0 : 1;
+ unsigned smallerSize = cs[1 - larger].size();
+ if (cs[1 - larger]
+ == (i == 0 ? cs[larger].suffix(smallerSize)
+ : cs[larger].prefix(smallerSize)))
+ {
+ unsigned sizeDiff = cs[larger].size() - smallerSize;
+ c[larger][c[larger].size() - 1] =
+ nm->mkConst(i == 0 ? cs[larger].prefix(sizeDiff)
+ : cs[larger].suffix(sizeDiff));
+ c[1 - larger].pop_back();
+ changed = true;
+ }
+ }
+ for (unsigned j = 0; j < 2; j++)
+ {
+ std::reverse(c[j].begin(), c[j].end());
+ }
+ }
+ if (changed)
+ {
+ // e.g. x++y = x++z ---> y = z, "AB" ++ x = "A" ++ y --> "B" ++ x = y
+ Node s1 = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT, c[0]);
+ Node s2 = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT, c[1]);
+ new_ret = s1.eqNode(s2);
+ debugExtendedRewrite(ret, new_ret, "string-eq-unify");
+ return new_ret;
+ }
+
+ // homogeneous constants
+ if (d_aggr)
+ {
+ for (unsigned i = 0; i < 2; i++)
+ {
+ if (ret[i].isConst())
+ {
+ bool isHomogeneous = true;
+ std::vector<unsigned> vec = ret[i].getConst<String>().getVec();
+ if (vec.size() > 1)
+ {
+ unsigned hchar = vec[0];
+ for (unsigned j = 1, size = vec.size(); j < size; j++)
+ {
+ if (vec[j] != hchar)
+ {
+ isHomogeneous = false;
+ break;
+ }
+ }
+ }
+ if (isHomogeneous && !std::is_sorted(c[1-i].begin(),c[1-i].end()))
+ {
+ Node ss = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT,
+ c[1 - i]);
+ Assert(ss != ret[1 - i]);
+ // e.g. "AA" = x ++ y ---> "AA" = y ++ x if y < x
+ new_ret = ret[i].eqNode(ss);
+ debugExtendedRewrite(ret, new_ret, "string-eq-homog-const");
+ return new_ret;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return new_ret;
+}
+
void ExtendedRewriter::debugExtendedRewrite(Node n,
Node ret,
const char* c) const