Clean up explanations involving string length. Add regression.
[cvc5.git] / src / theory / strings / theory_strings.h
1 /********************* */
2 /*! \file theory_strings.h
3 ** \verbatim
4 ** Original author: Tianyi Liang
5 ** Major contributors: Andrew Reynolds
6 ** Minor contributors (to current version): Martin Brain <>, Morgan Deters
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Theory of strings
13 **
14 ** Theory of strings.
15 **/
16
17 #include "cvc4_private.h"
18
19 #ifndef __CVC4__THEORY__STRINGS__THEORY_STRINGS_H
20 #define __CVC4__THEORY__STRINGS__THEORY_STRINGS_H
21
22 #include "theory/theory.h"
23 #include "theory/uf/equality_engine.h"
24 #include "theory/strings/theory_strings_preprocess.h"
25 #include "theory/strings/regexp_operation.h"
26
27 #include "context/cdchunk_list.h"
28 #include "context/cdhashset.h"
29 #include "expr/attribute.h"
30
31 #include <climits>
32 #include <deque>
33
34 namespace CVC4 {
35 namespace theory {
36 namespace strings {
37
38 /**
39 * Decision procedure for strings.
40 *
41 */
42
43 struct StringsProxyVarAttributeId {};
44 typedef expr::Attribute< StringsProxyVarAttributeId, bool > StringsProxyVarAttribute;
45
46 class TheoryStrings : public Theory {
47 typedef context::CDChunkList<Node> NodeList;
48 typedef context::CDHashMap<Node, NodeList*, NodeHashFunction> NodeListMap;
49 typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
50 typedef context::CDHashMap<Node, int, NodeHashFunction> NodeIntMap;
51 typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
52 typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
53
54 public:
55 TheoryStrings(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo);
56 ~TheoryStrings();
57
58 void setMasterEqualityEngine(eq::EqualityEngine* eq);
59
60 std::string identify() const { return std::string("TheoryStrings"); }
61
62 public:
63 void propagate(Effort e);
64 bool propagate(TNode literal);
65 void explain( TNode literal, std::vector<TNode>& assumptions );
66 Node explain( TNode literal );
67
68
69 // NotifyClass for equality engine
70 class NotifyClass : public eq::EqualityEngineNotify {
71 TheoryStrings& d_str;
72 public:
73 NotifyClass(TheoryStrings& t_str): d_str(t_str) {}
74 bool eqNotifyTriggerEquality(TNode equality, bool value) {
75 Debug("strings") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl;
76 if (value) {
77 return d_str.propagate(equality);
78 } else {
79 // We use only literal triggers so taking not is safe
80 return d_str.propagate(equality.notNode());
81 }
82 }
83 bool eqNotifyTriggerPredicate(TNode predicate, bool value) {
84 Debug("strings") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false") << ")" << std::endl;
85 if (value) {
86 return d_str.propagate(predicate);
87 } else {
88 return d_str.propagate(predicate.notNode());
89 }
90 }
91 bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
92 Debug("strings") << "NotifyClass::eqNotifyTriggerTermMerge(" << tag << ", " << t1 << ", " << t2 << ")" << std::endl;
93 if (value) {
94 return d_str.propagate(t1.eqNode(t2));
95 } else {
96 return d_str.propagate(t1.eqNode(t2).notNode());
97 }
98 }
99 void eqNotifyConstantTermMerge(TNode t1, TNode t2) {
100 Debug("strings") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
101 d_str.conflict(t1, t2);
102 }
103 void eqNotifyNewClass(TNode t) {
104 Debug("strings") << "NotifyClass::eqNotifyNewClass(" << t << std::endl;
105 d_str.eqNotifyNewClass(t);
106 }
107 void eqNotifyPreMerge(TNode t1, TNode t2) {
108 Debug("strings") << "NotifyClass::eqNotifyPreMerge(" << t1 << ", " << t2 << std::endl;
109 d_str.eqNotifyPreMerge(t1, t2);
110 }
111 void eqNotifyPostMerge(TNode t1, TNode t2) {
112 Debug("strings") << "NotifyClass::eqNotifyPostMerge(" << t1 << ", " << t2 << std::endl;
113 d_str.eqNotifyPostMerge(t1, t2);
114 }
115 void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {
116 Debug("strings") << "NotifyClass::eqNotifyDisequal(" << t1 << ", " << t2 << ", " << reason << std::endl;
117 d_str.eqNotifyDisequal(t1, t2, reason);
118 }
119 };/* class TheoryStrings::NotifyClass */
120
121 private:
122 // Constants
123 Node d_emptyString;
124 Node d_emptyRegexp;
125 Node d_true;
126 Node d_false;
127 Node d_zero;
128 Node d_one;
129 CVC4::Rational RMAXINT;
130 unsigned d_card_size;
131 // Helper functions
132 Node getRepresentative( Node t );
133 bool hasTerm( Node a );
134 bool areEqual( Node a, Node b );
135 bool areDisequal( Node a, Node b );
136 // t is representative, te = t, add lt = te to explanation exp
137 Node getLengthExp( Node t, std::vector< Node >& exp, Node te );
138 Node getLength( Node t, std::vector< Node >& exp );
139
140 private:
141 /** The notify class */
142 NotifyClass d_notify;
143 /** Equaltity engine */
144 eq::EqualityEngine d_equalityEngine;
145 /** Are we in conflict */
146 context::CDO<bool> d_conflict;
147 //list of pairs of nodes to merge
148 std::map< Node, Node > d_pending_exp;
149 std::vector< Node > d_pending;
150 std::vector< Node > d_lemma_cache;
151 std::map< Node, bool > d_pending_req_phase;
152 /** inferences: maintained to ensure ref count for internally introduced nodes */
153 NodeList d_infer;
154 NodeList d_infer_exp;
155 /** normal forms */
156 std::map< Node, Node > d_normal_forms_base;
157 std::map< Node, std::vector< Node > > d_normal_forms;
158 std::map< Node, std::vector< Node > > d_normal_forms_exp;
159 //map of pairs of terms that have the same normal form
160 NodeListMap d_nf_pairs;
161 void addNormalFormPair( Node n1, Node n2 );
162 bool isNormalFormPair( Node n1, Node n2 );
163 bool isNormalFormPair2( Node n1, Node n2 );
164 // loop ant
165 NodeSet d_loop_antec;
166 NodeSet d_length_intro_vars;
167 // preReg cache
168 NodeSet d_registered_terms_cache;
169 // preprocess cache
170 StringsPreprocess d_preproc;
171 NodeBoolMap d_preproc_cache;
172 // extended functions inferences cache
173 NodeSet d_extf_infer_cache;
174
175 private:
176 NodeSet d_congruent;
177 std::map< Node, Node > d_eqc_to_const;
178 std::map< Node, Node > d_eqc_to_const_base;
179 std::map< Node, Node > d_eqc_to_const_exp;
180 std::map< Node, Node > d_eqc_to_len_term;
181 std::vector< Node > d_strings_eqc;
182 Node d_emptyString_r;
183 class TermIndex {
184 public:
185 Node d_data;
186 std::map< Node, TermIndex > d_children;
187 Node add( Node n, unsigned index, TheoryStrings* t, Node er, std::vector< Node >& c );
188 void clear(){ d_children.clear(); }
189 };
190 std::map< Kind, TermIndex > d_term_index;
191 //list of non-congruent concat terms in each eqc
192 std::map< Node, std::vector< Node > > d_eqc;
193 std::map< Node, std::vector< Node > > d_flat_form;
194 std::map< Node, std::vector< int > > d_flat_form_index;
195
196 void debugPrintFlatForms( const char * tc );
197 /////////////////////////////////////////////////////////////////////////////
198 // MODEL GENERATION
199 /////////////////////////////////////////////////////////////////////////////
200 public:
201 void collectModelInfo(TheoryModel* m, bool fullModel);
202
203 /////////////////////////////////////////////////////////////////////////////
204 // NOTIFICATIONS
205 /////////////////////////////////////////////////////////////////////////////
206 public:
207 void presolve();
208 void shutdown() { }
209
210 /////////////////////////////////////////////////////////////////////////////
211 // MAIN SOLVER
212 /////////////////////////////////////////////////////////////////////////////
213 private:
214 void addSharedTerm(TNode n);
215 EqualityStatus getEqualityStatus(TNode a, TNode b);
216
217 private:
218 class EqcInfo {
219 public:
220 EqcInfo( context::Context* c );
221 ~EqcInfo(){}
222 //constant in this eqc
223 context::CDO< Node > d_const_term;
224 context::CDO< Node > d_length_term;
225 context::CDO< unsigned > d_cardinality_lem_k;
226 // 1 = added length lemma
227 context::CDO< Node > d_normalized_length;
228 };
229 /** map from representatives to information necessary for equivalence classes */
230 std::map< Node, EqcInfo* > d_eqc_info;
231 EqcInfo * getOrMakeEqcInfo( Node eqc, bool doMake = true );
232 //maintain which concat terms have the length lemma instantiated
233 NodeNodeMap d_proxy_var;
234 NodeNodeMap d_proxy_var_to_length;
235 private:
236
237 //initial check
238 void checkInit();
239 void checkConstantEquivalenceClasses( TermIndex* ti, std::vector< Node >& vecc );
240 //extended functions evaluation check
241 void checkExtendedFuncsEval( int effort = 0 );
242 void checkExtfInference( Node n, Node nr, int effort );
243 void collectVars( Node n, std::map< Node, std::vector< Node > >& vars, std::map< Node, bool >& visited );
244 Node getSymbolicDefinition( Node n, std::vector< Node >& exp );
245 //check extf reduction
246 void checkExtfReduction( int effort );
247 void checkReduction( Node atom, int pol, int effort );
248 //flat forms check
249 void checkFlatForms();
250 Node checkCycles( Node eqc, std::vector< Node >& curr, std::vector< Node >& exp );
251 //normal forms check
252 void checkNormalForms();
253 void mergeCstVec(std::vector< Node > &vec_strings);
254 bool getNormalForms(Node &eqc, std::vector< Node > & visited, std::vector< Node > & nf,
255 std::vector< std::vector< Node > > &normal_forms,
256 std::vector< std::vector< Node > > &normal_forms_exp,
257 std::vector< Node > &normal_form_src);
258 bool detectLoop(std::vector< std::vector< Node > > &normal_forms,
259 int i, int j, int index_i, int index_j,
260 int &loop_in_i, int &loop_in_j);
261 bool processLoop(std::vector< Node > &antec,
262 std::vector< std::vector< Node > > &normal_forms,
263 std::vector< Node > &normal_form_src,
264 int i, int j, int loop_n_index, int other_n_index,
265 int loop_index, int index, int other_index);
266 bool processNEqc(std::vector< std::vector< Node > > &normal_forms,
267 std::vector< std::vector< Node > > &normal_forms_exp,
268 std::vector< Node > &normal_form_src);
269 bool processReverseNEq(std::vector< std::vector< Node > > &normal_forms,
270 std::vector< Node > &normal_form_src, std::vector< Node > &curr_exp, unsigned i, unsigned j );
271 bool processSimpleNEq( std::vector< std::vector< Node > > &normal_forms,
272 std::vector< Node > &normal_form_src, std::vector< Node > &curr_exp, unsigned i, unsigned j,
273 unsigned& index_i, unsigned& index_j, bool isRev );
274 bool normalizeEquivalenceClass( Node n, std::vector< Node > & visited, std::vector< Node > & nf, std::vector< Node > & nf_exp );
275 bool processDeq( Node n1, Node n2 );
276 int processReverseDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj );
277 int processSimpleDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj, unsigned& index, bool isRev );
278 void checkDeqNF();
279
280 //check for extended functions
281 void checkExtendedFuncs();
282 //check membership constraints
283 Node mkRegExpAntec(Node atom, Node ant);
284 Node normalizeRegexp(Node r);
285 bool normalizePosMemberships(std::map< Node, std::vector< Node > > &memb_with_exps);
286 bool applyRConsume( CVC4::String &s, Node &r);
287 Node applyRSplit(Node s1, Node s2, Node r);
288 bool applyRLen(std::map< Node, std::vector< Node > > &XinR_with_exps);
289 bool checkMembershipsWithoutLength(
290 std::map< Node, std::vector< Node > > &memb_with_exps,
291 std::map< Node, std::vector< Node > > &XinR_with_exps);
292 void checkMemberships();
293 bool checkMemberships2();
294 bool checkPDerivative(Node x, Node r, Node atom, bool &addedLemma,
295 std::vector< Node > &processed, std::vector< Node > &cprocessed,
296 std::vector< Node > &nf_exp);
297 //check contains
298 void checkPosContains( std::vector< Node >& posContains );
299 void checkNegContains( std::vector< Node >& negContains );
300 //lengths normalize check
301 void checkLengthsEqc();
302 //cardinality check
303 void checkCardinality();
304
305 public:
306 /** preregister term */
307 void preRegisterTerm(TNode n);
308 /** Expand definition */
309 Node expandDefinition(LogicRequest &logicRequest, Node n);
310 /** Check at effort e */
311 void check(Effort e);
312 /** Conflict when merging two constants */
313 void conflict(TNode a, TNode b);
314 /** called when a new equivalence class is created */
315 void eqNotifyNewClass(TNode t);
316 /** called when two equivalence classes will merge */
317 void eqNotifyPreMerge(TNode t1, TNode t2);
318 /** called when two equivalence classes have merged */
319 void eqNotifyPostMerge(TNode t1, TNode t2);
320 /** called when two equivalence classes are made disequal */
321 void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
322 /** get preprocess */
323 StringsPreprocess * getPreprocess() { return &d_preproc; }
324 protected:
325 /** compute care graph */
326 void computeCareGraph();
327
328 //do pending merges
329 void assertPendingFact(Node atom, bool polarity, Node exp);
330 void doPendingFacts();
331 void doPendingLemmas();
332 bool hasProcessed();
333 void addToExplanation( Node a, Node b, std::vector< Node >& exp );
334 void addToExplanation( Node lit, std::vector< Node >& exp );
335
336 //register term
337 bool registerTerm( Node n );
338 //send lemma
339 void sendLemma( Node ant, Node conc, const char * c );
340 void sendInfer( Node eq_exp, Node eq, const char * c );
341 void sendSplit( Node a, Node b, const char * c, bool preq = true );
342 void sendLengthLemma( Node n );
343 /** mkConcat **/
344 inline Node mkConcat( Node n1, Node n2 );
345 inline Node mkConcat( Node n1, Node n2, Node n3 );
346 inline Node mkConcat( const std::vector< Node >& c );
347 inline Node mkLength( Node n );
348 //mkSkolem
349 inline Node mkSkolemS(const char * c, int isLenSplit = 0);
350 //inline Node mkSkolemI(const char * c);
351 /** mkExplain **/
352 Node mkExplain( std::vector< Node >& a );
353 Node mkExplain( std::vector< Node >& a, std::vector< Node >& an );
354 /** mkAnd **/
355 Node mkAnd( std::vector< Node >& a );
356 /** get concat vector */
357 void getConcatVec( Node n, std::vector< Node >& c );
358
359 //get equivalence classes
360 void getEquivalenceClasses( std::vector< Node >& eqcs );
361 //get final normal form
362 void getFinalNormalForm( Node n, std::vector< Node >& nf, std::vector< Node >& exp );
363
364 //separate into collections with equal length
365 void separateByLength( std::vector< Node >& n, std::vector< std::vector< Node > >& col, std::vector< Node >& lts );
366 void printConcat( std::vector< Node >& n, const char * c );
367
368 void inferSubstitutionProxyVars( Node n, std::vector< Node >& vars, std::vector< Node >& subs, std::vector< Node >& unproc );
369
370 enum {
371 sk_id_c_spt,
372 sk_id_vc_spt,
373 sk_id_v_spt,
374 sk_id_ctn_pre,
375 sk_id_ctn_post,
376 sk_id_deq_x,
377 sk_id_deq_y,
378 sk_id_deq_z,
379 };
380 std::map< Node, std::map< Node, std::map< int, Node > > > d_skolem_cache;
381 Node mkSkolemCached( Node a, Node b, int id, const char * c, int isLenSplit = 0 );
382 private:
383
384 // Special String Functions
385 NodeSet d_neg_ctn_eqlen;
386 NodeSet d_neg_ctn_ulen;
387 NodeSet d_neg_ctn_cached;
388 //extended string terms and whether they have been reduced
389 NodeBoolMap d_ext_func_terms;
390 std::map< Node, std::map< Node, std::vector< Node > > > d_extf_vars;
391 // list of terms that something (does not) contain and their explanation
392 class ExtfInfo {
393 public:
394 std::map< bool, std::vector< Node > > d_ctn;
395 std::map< bool, std::vector< Node > > d_ctn_from;
396 };
397 std::map< Node, int > d_extf_pol;
398 std::map< Node, std::vector< Node > > d_extf_exp;
399 std::map< Node, ExtfInfo > d_extf_info;
400 //collect extended operator terms
401 void collectExtendedFuncTerms( Node n, std::map< Node, bool >& visited );
402
403 // Symbolic Regular Expression
404 private:
405 // regular expression memberships
406 NodeList d_regexp_memberships;
407 NodeSet d_regexp_ucached;
408 NodeSet d_regexp_ccached;
409 // stored assertions
410 NodeListMap d_pos_memberships;
411 NodeListMap d_neg_memberships;
412 // semi normal forms for symbolic expression
413 std::map< Node, Node > d_nf_regexps;
414 std::map< Node, std::vector< Node > > d_nf_regexps_exp;
415 // intersection
416 NodeNodeMap d_inter_cache;
417 NodeIntMap d_inter_index;
418 // processed memberships
419 NodeSet d_processed_memberships;
420 // antecedant for why regexp membership must be true
421 NodeNodeMap d_regexp_ant;
422 // membership length
423 //std::map< Node, bool > d_membership_length;
424 // regular expression operations
425 RegExpOpr d_regexp_opr;
426
427 CVC4::String getHeadConst( Node x );
428 bool deriveRegExp( Node x, Node r, Node ant );
429 bool addMembershipLength(Node atom);
430 void addMembership(Node assertion);
431 Node getNormalString(Node x, std::vector<Node> &nf_exp);
432 Node getNormalSymRegExp(Node r, std::vector<Node> &nf_exp);
433
434
435 // Finite Model Finding
436 private:
437 NodeSet d_input_vars;
438 context::CDO< Node > d_input_var_lsum;
439 context::CDHashMap< int, Node > d_cardinality_lits;
440 context::CDO< int > d_curr_cardinality;
441 public:
442 //for finite model finding
443 Node getNextDecisionRequest();
444
445 public:
446 /** statistics class */
447 class Statistics {
448 public:
449 IntStat d_splits;
450 IntStat d_eq_splits;
451 IntStat d_deq_splits;
452 IntStat d_loop_lemmas;
453 IntStat d_new_skolems;
454 Statistics();
455 ~Statistics();
456 };/* class TheoryStrings::Statistics */
457 Statistics d_statistics;
458 };/* class TheoryStrings */
459
460 }/* CVC4::theory::strings namespace */
461 }/* CVC4::theory namespace */
462 }/* CVC4 namespace */
463
464 #endif /* __CVC4__THEORY__STRINGS__THEORY_STRINGS_H */