55539c802e39b5f35a44b8efbe8465be8e6763de
[cvc5.git] / src / theory / strings / solver_state.cpp
1 /********************* */
2 /*! \file solver_state.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of the solver state of the theory of strings.
13 **/
14
15 #include "theory/strings/solver_state.h"
16
17 #include "theory/strings/theory_strings_utils.h"
18 #include "theory/strings/word.h"
19
20 using namespace std;
21 using namespace CVC4::context;
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25 namespace theory {
26 namespace strings {
27
28 SolverState::SolverState(context::Context* c,
29 eq::EqualityEngine& ee,
30 Valuation& v)
31 : d_context(c),
32 d_ee(ee),
33 d_eeDisequalities(c),
34 d_valuation(v),
35 d_conflict(c, false),
36 d_pendingConflict(c)
37 {
38 d_zero = NodeManager::currentNM()->mkConst(Rational(0));
39 }
40
41 SolverState::~SolverState()
42 {
43 for (std::pair<const Node, EqcInfo*>& it : d_eqcInfo)
44 {
45 delete it.second;
46 }
47 }
48
49 Node SolverState::getRepresentative(Node t) const
50 {
51 if (d_ee.hasTerm(t))
52 {
53 return d_ee.getRepresentative(t);
54 }
55 return t;
56 }
57
58 bool SolverState::hasTerm(Node a) const { return d_ee.hasTerm(a); }
59
60 bool SolverState::areEqual(Node a, Node b) const
61 {
62 if (a == b)
63 {
64 return true;
65 }
66 else if (hasTerm(a) && hasTerm(b))
67 {
68 return d_ee.areEqual(a, b);
69 }
70 return false;
71 }
72
73 bool SolverState::areDisequal(Node a, Node b) const
74 {
75 if (a == b)
76 {
77 return false;
78 }
79 else if (hasTerm(a) && hasTerm(b))
80 {
81 Node ar = d_ee.getRepresentative(a);
82 Node br = d_ee.getRepresentative(b);
83 return (ar != br && ar.isConst() && br.isConst())
84 || d_ee.areDisequal(ar, br, false);
85 }
86 Node ar = getRepresentative(a);
87 Node br = getRepresentative(b);
88 return ar != br && ar.isConst() && br.isConst();
89 }
90
91 eq::EqualityEngine* SolverState::getEqualityEngine() const { return &d_ee; }
92
93 const context::CDList<Node>& SolverState::getDisequalityList() const
94 {
95 return d_eeDisequalities;
96 }
97
98 void SolverState::eqNotifyNewClass(TNode t)
99 {
100 Kind k = t.getKind();
101 if (k == STRING_LENGTH || k == STRING_TO_CODE)
102 {
103 Node r = d_ee.getRepresentative(t[0]);
104 EqcInfo* ei = getOrMakeEqcInfo(r);
105 if (k == STRING_LENGTH)
106 {
107 ei->d_lengthTerm = t[0];
108 }
109 else
110 {
111 ei->d_codeTerm = t[0];
112 }
113 }
114 else if (t.isConst())
115 {
116 EqcInfo* ei = getOrMakeEqcInfo(t);
117 ei->d_prefixC = t;
118 ei->d_suffixC = t;
119 return;
120 }
121 else if (k == STRING_CONCAT)
122 {
123 addEndpointsToEqcInfo(t, t, t);
124 }
125 }
126
127 void SolverState::eqNotifyPreMerge(TNode t1, TNode t2)
128 {
129 EqcInfo* e2 = getOrMakeEqcInfo(t2, false);
130 if (e2)
131 {
132 EqcInfo* e1 = getOrMakeEqcInfo(t1);
133 // add information from e2 to e1
134 if (!e2->d_lengthTerm.get().isNull())
135 {
136 e1->d_lengthTerm.set(e2->d_lengthTerm);
137 }
138 if (!e2->d_codeTerm.get().isNull())
139 {
140 e1->d_codeTerm.set(e2->d_codeTerm);
141 }
142 if (!e2->d_prefixC.get().isNull())
143 {
144 setPendingConflictWhen(
145 e1->addEndpointConst(e2->d_prefixC, Node::null(), false));
146 }
147 if (!e2->d_suffixC.get().isNull())
148 {
149 setPendingConflictWhen(
150 e1->addEndpointConst(e2->d_suffixC, Node::null(), true));
151 }
152 if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get())
153 {
154 e1->d_cardinalityLemK.set(e2->d_cardinalityLemK);
155 }
156 if (!e2->d_normalizedLength.get().isNull())
157 {
158 e1->d_normalizedLength.set(e2->d_normalizedLength);
159 }
160 }
161 }
162
163 void SolverState::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
164 {
165 if (t1.getType().isStringLike())
166 {
167 // store disequalities between strings, may need to check if their lengths
168 // are equal/disequal
169 d_eeDisequalities.push_back(t1.eqNode(t2));
170 }
171 }
172
173 EqcInfo* SolverState::getOrMakeEqcInfo(Node eqc, bool doMake)
174 {
175 std::map<Node, EqcInfo*>::iterator eqc_i = d_eqcInfo.find(eqc);
176 if (eqc_i != d_eqcInfo.end())
177 {
178 return eqc_i->second;
179 }
180 if (doMake)
181 {
182 EqcInfo* ei = new EqcInfo(d_context);
183 d_eqcInfo[eqc] = ei;
184 return ei;
185 }
186 return nullptr;
187 }
188
189 TheoryModel* SolverState::getModel() const { return d_valuation.getModel(); }
190
191 void SolverState::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
192 {
193 Assert(concat.getKind() == STRING_CONCAT
194 || concat.getKind() == REGEXP_CONCAT);
195 EqcInfo* ei = nullptr;
196 // check each side
197 for (unsigned r = 0; r < 2; r++)
198 {
199 unsigned index = r == 0 ? 0 : concat.getNumChildren() - 1;
200 Node c = utils::getConstantComponent(concat[index]);
201 if (!c.isNull())
202 {
203 if (ei == nullptr)
204 {
205 ei = getOrMakeEqcInfo(eqc);
206 }
207 Trace("strings-eager-pconf-debug")
208 << "New term: " << concat << " for " << t << " with prefix " << c
209 << " (" << (r == 1) << ")" << std::endl;
210 setPendingConflictWhen(ei->addEndpointConst(t, c, r == 1));
211 }
212 }
213 }
214
215 Node SolverState::getLengthExp(Node t, std::vector<Node>& exp, Node te)
216 {
217 Assert(areEqual(t, te));
218 Node lt = utils::mkNLength(te);
219 if (hasTerm(lt))
220 {
221 // use own length if it exists, leads to shorter explanation
222 return lt;
223 }
224 EqcInfo* ei = getOrMakeEqcInfo(t, false);
225 Node lengthTerm = ei ? ei->d_lengthTerm : Node::null();
226 if (lengthTerm.isNull())
227 {
228 // typically shouldnt be necessary
229 lengthTerm = t;
230 }
231 Debug("strings") << "SolverState::getLengthTerm " << t << " is " << lengthTerm
232 << std::endl;
233 if (te != lengthTerm)
234 {
235 exp.push_back(te.eqNode(lengthTerm));
236 }
237 return Rewriter::rewrite(
238 NodeManager::currentNM()->mkNode(STRING_LENGTH, lengthTerm));
239 }
240
241 Node SolverState::getLength(Node t, std::vector<Node>& exp)
242 {
243 return getLengthExp(t, exp, t);
244 }
245
246 Node SolverState::explainNonEmpty(Node s)
247 {
248 Assert(s.getType().isStringLike());
249 Node emp = Word::mkEmptyWord(s.getType());
250 if (areDisequal(s, emp))
251 {
252 return s.eqNode(emp).negate();
253 }
254 Node sLen = utils::mkNLength(s);
255 if (areDisequal(sLen, d_zero))
256 {
257 return sLen.eqNode(d_zero).negate();
258 }
259 return Node::null();
260 }
261
262 void SolverState::setConflict() { d_conflict = true; }
263 bool SolverState::isInConflict() const { return d_conflict; }
264
265 void SolverState::setPendingConflictWhen(Node conf)
266 {
267 if (!conf.isNull() && d_pendingConflict.get().isNull())
268 {
269 d_pendingConflict = conf;
270 }
271 }
272
273 Node SolverState::getPendingConflict() const { return d_pendingConflict; }
274
275 std::pair<bool, Node> SolverState::entailmentCheck(options::TheoryOfMode mode,
276 TNode lit)
277 {
278 return d_valuation.entailmentCheck(mode, lit);
279 }
280
281 void SolverState::separateByLength(const std::vector<Node>& n,
282 std::vector<std::vector<Node> >& cols,
283 std::vector<Node>& lts)
284 {
285 unsigned leqc_counter = 0;
286 std::map<Node, unsigned> eqc_to_leqc;
287 std::map<unsigned, Node> leqc_to_eqc;
288 std::map<unsigned, std::vector<Node> > eqc_to_strings;
289 NodeManager* nm = NodeManager::currentNM();
290 for (const Node& eqc : n)
291 {
292 Assert(d_ee.getRepresentative(eqc) == eqc);
293 EqcInfo* ei = getOrMakeEqcInfo(eqc, false);
294 Node lt = ei ? ei->d_lengthTerm : Node::null();
295 if (!lt.isNull())
296 {
297 lt = nm->mkNode(STRING_LENGTH, lt);
298 Node r = d_ee.getRepresentative(lt);
299 if (eqc_to_leqc.find(r) == eqc_to_leqc.end())
300 {
301 eqc_to_leqc[r] = leqc_counter;
302 leqc_to_eqc[leqc_counter] = r;
303 leqc_counter++;
304 }
305 eqc_to_strings[eqc_to_leqc[r]].push_back(eqc);
306 }
307 else
308 {
309 eqc_to_strings[leqc_counter].push_back(eqc);
310 leqc_counter++;
311 }
312 }
313 for (const std::pair<const unsigned, std::vector<Node> >& p : eqc_to_strings)
314 {
315 cols.push_back(std::vector<Node>());
316 cols.back().insert(cols.back().end(), p.second.begin(), p.second.end());
317 lts.push_back(leqc_to_eqc[p.first]);
318 }
319 }
320
321 } // namespace strings
322 } // namespace theory
323 } // namespace CVC4