Refactoring of inferences in strings. Add several options.
[cvc5.git] / src / theory / strings / theory_strings.h
1 /********************* */
2 /*! \file theory_strings.h
3 ** \verbatim
4 ** Original author: Tianyi Liang
5 ** Major contributors: Andrew Reynolds
6 ** Minor contributors (to current version): Martin Brain <>, Morgan Deters
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Theory of strings
13 **
14 ** Theory of strings.
15 **/
16
17 #include "cvc4_private.h"
18
19 #ifndef __CVC4__THEORY__STRINGS__THEORY_STRINGS_H
20 #define __CVC4__THEORY__STRINGS__THEORY_STRINGS_H
21
22 #include "theory/theory.h"
23 #include "theory/uf/equality_engine.h"
24 #include "theory/strings/theory_strings_preprocess.h"
25 #include "theory/strings/regexp_operation.h"
26
27 #include "context/cdchunk_list.h"
28 #include "context/cdhashset.h"
29 #include "expr/attribute.h"
30
31 #include <climits>
32 #include <deque>
33
34 namespace CVC4 {
35 namespace theory {
36 namespace strings {
37
38 /**
39 * Decision procedure for strings.
40 *
41 */
42
43 struct StringsProxyVarAttributeId {};
44 typedef expr::Attribute< StringsProxyVarAttributeId, bool > StringsProxyVarAttribute;
45
46 class TheoryStrings : public Theory {
47 typedef context::CDChunkList<Node> NodeList;
48 typedef context::CDHashMap<Node, NodeList*, NodeHashFunction> NodeListMap;
49 typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
50 typedef context::CDHashMap<Node, int, NodeHashFunction> NodeIntMap;
51 typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
52 typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
53
54 public:
55 TheoryStrings(context::Context* c, context::UserContext* u,
56 OutputChannel& out, Valuation valuation,
57 const LogicInfo& logicInfo);
58 ~TheoryStrings();
59
60 void setMasterEqualityEngine(eq::EqualityEngine* eq);
61
62 std::string identify() const { return std::string("TheoryStrings"); }
63
64 public:
65 void propagate(Effort e);
66 bool propagate(TNode literal);
67 void explain( TNode literal, std::vector<TNode>& assumptions );
68 Node explain( TNode literal );
69
70
71 // NotifyClass for equality engine
72 class NotifyClass : public eq::EqualityEngineNotify {
73 TheoryStrings& d_str;
74 public:
75 NotifyClass(TheoryStrings& t_str): d_str(t_str) {}
76 bool eqNotifyTriggerEquality(TNode equality, bool value) {
77 Debug("strings") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl;
78 if (value) {
79 return d_str.propagate(equality);
80 } else {
81 // We use only literal triggers so taking not is safe
82 return d_str.propagate(equality.notNode());
83 }
84 }
85 bool eqNotifyTriggerPredicate(TNode predicate, bool value) {
86 Debug("strings") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false") << ")" << std::endl;
87 if (value) {
88 return d_str.propagate(predicate);
89 } else {
90 return d_str.propagate(predicate.notNode());
91 }
92 }
93 bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
94 Debug("strings") << "NotifyClass::eqNotifyTriggerTermMerge(" << tag << ", " << t1 << ", " << t2 << ")" << std::endl;
95 if (value) {
96 return d_str.propagate(t1.eqNode(t2));
97 } else {
98 return d_str.propagate(t1.eqNode(t2).notNode());
99 }
100 }
101 void eqNotifyConstantTermMerge(TNode t1, TNode t2) {
102 Debug("strings") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
103 d_str.conflict(t1, t2);
104 }
105 void eqNotifyNewClass(TNode t) {
106 Debug("strings") << "NotifyClass::eqNotifyNewClass(" << t << std::endl;
107 d_str.eqNotifyNewClass(t);
108 }
109 void eqNotifyPreMerge(TNode t1, TNode t2) {
110 Debug("strings") << "NotifyClass::eqNotifyPreMerge(" << t1 << ", " << t2 << std::endl;
111 d_str.eqNotifyPreMerge(t1, t2);
112 }
113 void eqNotifyPostMerge(TNode t1, TNode t2) {
114 Debug("strings") << "NotifyClass::eqNotifyPostMerge(" << t1 << ", " << t2 << std::endl;
115 d_str.eqNotifyPostMerge(t1, t2);
116 }
117 void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {
118 Debug("strings") << "NotifyClass::eqNotifyDisequal(" << t1 << ", " << t2 << ", " << reason << std::endl;
119 d_str.eqNotifyDisequal(t1, t2, reason);
120 }
121 };/* class TheoryStrings::NotifyClass */
122
123 private:
124 // Constants
125 Node d_emptyString;
126 Node d_emptyRegexp;
127 Node d_true;
128 Node d_false;
129 Node d_zero;
130 Node d_one;
131 CVC4::Rational RMAXINT;
132 unsigned d_card_size;
133 // Helper functions
134 Node getRepresentative( Node t );
135 bool hasTerm( Node a );
136 bool areEqual( Node a, Node b );
137 bool areDisequal( Node a, Node b );
138 // t is representative, te = t, add lt = te to explanation exp
139 Node getLengthExp( Node t, std::vector< Node >& exp, Node te );
140 Node getLength( Node t, std::vector< Node >& exp );
141
142 private:
143 /** The notify class */
144 NotifyClass d_notify;
145 /** Equaltity engine */
146 eq::EqualityEngine d_equalityEngine;
147 /** Are we in conflict */
148 context::CDO<bool> d_conflict;
149 //list of pairs of nodes to merge
150 std::map< Node, Node > d_pending_exp;
151 std::vector< Node > d_pending;
152 std::vector< Node > d_lemma_cache;
153 std::map< Node, bool > d_pending_req_phase;
154 /** inferences: maintained to ensure ref count for internally introduced nodes */
155 NodeList d_infer;
156 NodeList d_infer_exp;
157 /** normal forms */
158 std::map< Node, Node > d_normal_forms_base;
159 std::map< Node, std::vector< Node > > d_normal_forms;
160 std::map< Node, std::vector< Node > > d_normal_forms_exp;
161 //map of pairs of terms that have the same normal form
162 NodeListMap d_nf_pairs;
163 void addNormalFormPair( Node n1, Node n2 );
164 bool isNormalFormPair( Node n1, Node n2 );
165 bool isNormalFormPair2( Node n1, Node n2 );
166 // loop ant
167 NodeSet d_loop_antec;
168 NodeSet d_length_intro_vars;
169 // preReg cache
170 NodeSet d_pregistered_terms_cache;
171 NodeSet d_registered_terms_cache;
172 // preprocess cache
173 StringsPreprocess d_preproc;
174 NodeBoolMap d_preproc_cache;
175 // extended functions inferences cache
176 NodeSet d_extf_infer_cache;
177 std::vector< Node > d_empty_vec;
178 private:
179 NodeSet d_congruent;
180 std::map< Node, Node > d_eqc_to_const;
181 std::map< Node, Node > d_eqc_to_const_base;
182 std::map< Node, Node > d_eqc_to_const_exp;
183 std::map< Node, Node > d_eqc_to_len_term;
184 std::vector< Node > d_strings_eqc;
185 Node d_emptyString_r;
186 class TermIndex {
187 public:
188 Node d_data;
189 std::map< Node, TermIndex > d_children;
190 Node add( Node n, unsigned index, TheoryStrings* t, Node er, std::vector< Node >& c );
191 void clear(){ d_children.clear(); }
192 };
193 std::map< Kind, TermIndex > d_term_index;
194 //list of non-congruent concat terms in each eqc
195 std::map< Node, std::vector< Node > > d_eqc;
196 std::map< Node, std::vector< Node > > d_flat_form;
197 std::map< Node, std::vector< int > > d_flat_form_index;
198
199 void debugPrintFlatForms( const char * tc );
200 /////////////////////////////////////////////////////////////////////////////
201 // MODEL GENERATION
202 /////////////////////////////////////////////////////////////////////////////
203 public:
204 void collectModelInfo(TheoryModel* m, bool fullModel);
205
206 /////////////////////////////////////////////////////////////////////////////
207 // NOTIFICATIONS
208 /////////////////////////////////////////////////////////////////////////////
209 public:
210 void presolve();
211 void shutdown() { }
212
213 /////////////////////////////////////////////////////////////////////////////
214 // MAIN SOLVER
215 /////////////////////////////////////////////////////////////////////////////
216 private:
217 void addSharedTerm(TNode n);
218 EqualityStatus getEqualityStatus(TNode a, TNode b);
219
220 private:
221 class EqcInfo {
222 public:
223 EqcInfo( context::Context* c );
224 ~EqcInfo(){}
225 //constant in this eqc
226 context::CDO< Node > d_const_term;
227 context::CDO< Node > d_length_term;
228 context::CDO< unsigned > d_cardinality_lem_k;
229 // 1 = added length lemma
230 context::CDO< Node > d_normalized_length;
231 };
232 /** map from representatives to information necessary for equivalence classes */
233 std::map< Node, EqcInfo* > d_eqc_info;
234 EqcInfo * getOrMakeEqcInfo( Node eqc, bool doMake = true );
235 //maintain which concat terms have the length lemma instantiated
236 NodeNodeMap d_proxy_var;
237 NodeNodeMap d_proxy_var_to_length;
238 private:
239
240 //initial check
241 void checkInit();
242 void checkConstantEquivalenceClasses( TermIndex* ti, std::vector< Node >& vecc );
243 //extended functions evaluation check
244 void checkExtendedFuncsEval( int effort = 0 );
245 void checkExtfInference( Node n, Node nr, int effort );
246 void collectVars( Node n, std::map< Node, std::vector< Node > >& vars, std::map< Node, bool >& visited );
247 Node getSymbolicDefinition( Node n, std::vector< Node >& exp );
248 //check extf reduction
249 void checkExtfReduction( int effort );
250 void checkReduction( Node atom, int pol, int effort );
251 //flat forms check
252 void checkFlatForms();
253 Node checkCycles( Node eqc, std::vector< Node >& curr, std::vector< Node >& exp );
254 //normal forms check
255 void checkNormalForms();
256 bool normalizeEquivalenceClass( Node n, std::vector< Node > & nf, std::vector< Node > & nf_exp );
257 bool getNormalForms( Node &eqc, std::vector< Node > & nf,
258 std::vector< std::vector< Node > > &normal_forms,
259 std::vector< std::vector< Node > > &normal_forms_exp,
260 std::vector< Node > &normal_form_src);
261 bool detectLoop(std::vector< std::vector< Node > > &normal_forms,
262 int i, int j, int index_i, int index_j,
263 int &loop_in_i, int &loop_in_j);
264 bool processLoop(std::vector< Node > &antec,
265 std::vector< std::vector< Node > > &normal_forms,
266 std::vector< Node > &normal_form_src,
267 int i, int j, int loop_n_index, int other_n_index,
268 int loop_index, int index, int other_index);
269 bool processNEqc(std::vector< std::vector< Node > > &normal_forms,
270 std::vector< std::vector< Node > > &normal_forms_exp,
271 std::vector< Node > &normal_form_src);
272 bool processReverseNEq(std::vector< std::vector< Node > > &normal_forms,
273 std::vector< Node > &normal_form_src, std::vector< Node > &curr_exp, unsigned i, unsigned j );
274 bool processSimpleNEq( std::vector< std::vector< Node > > &normal_forms,
275 std::vector< Node > &normal_form_src, std::vector< Node > &curr_exp, unsigned i, unsigned j,
276 unsigned& index_i, unsigned& index_j, bool isRev );
277 bool processDeq( Node n1, Node n2 );
278 int processReverseDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj );
279 int processSimpleDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj, unsigned& index, bool isRev );
280 void checkDeqNF();
281
282 //check for extended functions
283 void checkExtendedFuncs();
284 //check membership constraints
285 Node mkRegExpAntec(Node atom, Node ant);
286 Node normalizeRegexp(Node r);
287 bool normalizePosMemberships( std::map< Node, std::vector< Node > > &memb_with_exps );
288 bool applyRConsume( CVC4::String &s, Node &r );
289 Node applyRSplit( Node s1, Node s2, Node r );
290 bool applyRLen( std::map< Node, std::vector< Node > > &XinR_with_exps );
291 bool checkMembershipsWithoutLength( std::map< Node, std::vector< Node > > &memb_with_exps,
292 std::map< Node, std::vector< Node > > &XinR_with_exps);
293 void checkMemberships();
294 bool checkMemberships2();
295 bool checkPDerivative( Node x, Node r, Node atom, bool &addedLemma,
296 std::vector< Node > &processed, std::vector< Node > &cprocessed,
297 std::vector< Node > &nf_exp);
298 //check contains
299 void checkPosContains( std::vector< Node >& posContains );
300 void checkNegContains( std::vector< Node >& negContains );
301 //lengths normalize check
302 void checkLengthsEqc();
303 //cardinality check
304 void checkCardinality();
305
306 public:
307 /** preregister term */
308 void preRegisterTerm(TNode n);
309 /** Expand definition */
310 Node expandDefinition(LogicRequest &logicRequest, Node n);
311 /** Check at effort e */
312 void check(Effort e);
313 /** Conflict when merging two constants */
314 void conflict(TNode a, TNode b);
315 /** called when a new equivalence class is created */
316 void eqNotifyNewClass(TNode t);
317 /** called when two equivalence classes will merge */
318 void eqNotifyPreMerge(TNode t1, TNode t2);
319 /** called when two equivalence classes have merged */
320 void eqNotifyPostMerge(TNode t1, TNode t2);
321 /** called when two equivalence classes are made disequal */
322 void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
323 /** get preprocess */
324 StringsPreprocess * getPreprocess() { return &d_preproc; }
325 protected:
326 /** compute care graph */
327 void computeCareGraph();
328
329 //do pending merges
330 void assertPendingFact(Node atom, bool polarity, Node exp);
331 void doPendingFacts();
332 void doPendingLemmas();
333 bool hasProcessed();
334 void addToExplanation( Node a, Node b, std::vector< Node >& exp );
335 void addToExplanation( Node lit, std::vector< Node >& exp );
336
337 //register term
338 void registerTerm( Node n, int effort );
339 //send lemma
340 void sendInference( std::vector< Node >& exp, std::vector< Node >& exp_n, Node eq, const char * c, bool asLemma = false );
341 void sendInference( std::vector< Node >& exp, Node eq, const char * c, bool asLemma = false );
342 void sendLemma( Node ant, Node conc, const char * c );
343 void sendInfer( Node eq_exp, Node eq, const char * c );
344 void sendSplit( Node a, Node b, const char * c, bool preq = true );
345 void sendLengthLemma( Node n );
346 /** mkConcat **/
347 inline Node mkConcat( Node n1, Node n2 );
348 inline Node mkConcat( Node n1, Node n2, Node n3 );
349 inline Node mkConcat( const std::vector< Node >& c );
350 inline Node mkLength( Node n );
351 //mkSkolem
352 enum {
353 sk_id_c_spt,
354 sk_id_vc_spt,
355 sk_id_v_spt,
356 sk_id_ctn_pre,
357 sk_id_ctn_post,
358 sk_id_deq_x,
359 sk_id_deq_y,
360 sk_id_deq_z,
361 };
362 std::map< Node, std::map< Node, std::map< int, Node > > > d_skolem_cache;
363 Node mkSkolemCached( Node a, Node b, int id, const char * c, int isLenSplit = 0 );
364 inline Node mkSkolemS(const char * c, int isLenSplit = 0);
365 //inline Node mkSkolemI(const char * c);
366 /** mkExplain **/
367 Node mkExplain( std::vector< Node >& a );
368 Node mkExplain( std::vector< Node >& a, std::vector< Node >& an );
369 /** mkAnd **/
370 Node mkAnd( std::vector< Node >& a );
371 /** get concat vector */
372 void getConcatVec( Node n, std::vector< Node >& c );
373
374 //get equivalence classes
375 void getEquivalenceClasses( std::vector< Node >& eqcs );
376 //get final normal form
377 void getFinalNormalForm( Node n, std::vector< Node >& nf, std::vector< Node >& exp );
378
379 //separate into collections with equal length
380 void separateByLength( std::vector< Node >& n, std::vector< std::vector< Node > >& col, std::vector< Node >& lts );
381 void printConcat( std::vector< Node >& n, const char * c );
382
383 void inferSubstitutionProxyVars( Node n, std::vector< Node >& vars, std::vector< Node >& subs, std::vector< Node >& unproc );
384 private:
385
386 // Special String Functions
387 NodeSet d_neg_ctn_eqlen;
388 NodeSet d_neg_ctn_ulen;
389 NodeSet d_neg_ctn_cached;
390 //extended string terms and whether they have been reduced
391 NodeBoolMap d_ext_func_terms;
392 std::map< Node, std::map< Node, std::vector< Node > > > d_extf_vars;
393 // list of terms that something (does not) contain and their explanation
394 class ExtfInfo {
395 public:
396 std::map< bool, std::vector< Node > > d_ctn;
397 std::map< bool, std::vector< Node > > d_ctn_from;
398 };
399 std::map< Node, int > d_extf_pol;
400 std::map< Node, std::vector< Node > > d_extf_exp;
401 std::map< Node, ExtfInfo > d_extf_info;
402 //collect extended operator terms
403 void collectExtendedFuncTerms( Node n, std::map< Node, bool >& visited );
404
405 // Symbolic Regular Expression
406 private:
407 // regular expression memberships
408 NodeList d_regexp_memberships;
409 NodeSet d_regexp_ucached;
410 NodeSet d_regexp_ccached;
411 // stored assertions
412 NodeListMap d_pos_memberships;
413 NodeListMap d_neg_memberships;
414 // semi normal forms for symbolic expression
415 std::map< Node, Node > d_nf_regexps;
416 std::map< Node, std::vector< Node > > d_nf_regexps_exp;
417 // intersection
418 NodeNodeMap d_inter_cache;
419 NodeIntMap d_inter_index;
420 // processed memberships
421 NodeSet d_processed_memberships;
422 // antecedant for why regexp membership must be true
423 NodeNodeMap d_regexp_ant;
424 // membership length
425 //std::map< Node, bool > d_membership_length;
426 // regular expression operations
427 RegExpOpr d_regexp_opr;
428
429 CVC4::String getHeadConst( Node x );
430 bool deriveRegExp( Node x, Node r, Node ant );
431 bool addMembershipLength(Node atom);
432 void addMembership(Node assertion);
433 Node getNormalString(Node x, std::vector<Node> &nf_exp);
434 Node getNormalSymRegExp(Node r, std::vector<Node> &nf_exp);
435
436
437 // Finite Model Finding
438 private:
439 NodeSet d_input_vars;
440 context::CDO< Node > d_input_var_lsum;
441 context::CDHashMap< int, Node > d_cardinality_lits;
442 context::CDO< int > d_curr_cardinality;
443 public:
444 //for finite model finding
445 Node getNextDecisionRequest();
446
447 public:
448 /** statistics class */
449 class Statistics {
450 public:
451 IntStat d_splits;
452 IntStat d_eq_splits;
453 IntStat d_deq_splits;
454 IntStat d_loop_lemmas;
455 IntStat d_new_skolems;
456 Statistics();
457 ~Statistics();
458 };/* class TheoryStrings::Statistics */
459 Statistics d_statistics;
460 };/* class TheoryStrings */
461
462 }/* CVC4::theory::strings namespace */
463 }/* CVC4::theory namespace */
464 }/* CVC4 namespace */
465
466 #endif /* __CVC4__THEORY__STRINGS__THEORY_STRINGS_H */