1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Andres Noetzli, Tianyi Liang
5 * This file is part of the cvc5 project.
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
13 * Base solver for the theory of strings. This class implements term
14 * indexing and constant inference for the theory of strings.
17 #include "theory/strings/base_solver.h"
19 #include "expr/sequence.h"
20 #include "options/strings_options.h"
21 #include "theory/rewriter.h"
22 #include "theory/strings/theory_strings_utils.h"
23 #include "theory/strings/word.h"
24 #include "util/rational.h"
27 using namespace cvc5::context
;
28 using namespace cvc5::kind
;
// Constructs the base solver from the solver state `s` and the inference
// manager `im`. The initializer list stores them in d_state / d_im and builds
// d_congruent from s.getSatContext(). The visible body then caches the Boolean
// constant `false` in d_false and the result of utils::getAlphabetCardinality()
// in d_cardSize.
34 BaseSolver::BaseSolver(SolverState
& s
, InferenceManager
& im
)
35 : d_state(s
), d_im(im
), d_congruent(s
.getSatContext())
37 d_false
= NodeManager::currentNM()->mkConst(false);
38 d_cardSize
= utils::getAlphabetCardinality();
// Destructor with an empty body: no explicit cleanup is performed here.
41 BaseSolver::~BaseSolver() {}
// Initialization pass over the current equivalence classes.
//
// Walks every equivalence class of the equality engine obtained from d_state
// and, for each term n in the class:
//  - constant-like terms (utils::isConstantLike): may send
//    STRINGS_UNIT_CONST_CONFLICT or STRINGS_UNIT_INJ, and may become the
//    class's d_bestContent / d_base in d_eqcInfo;
//  - terms with children: are added to the per-type, per-kind term index
//    (d_termIndex); terms found congruent to an indexed term are inserted into
//    d_congruent, possibly after sending STRINGS_I_NORM or STRINGS_I_NORM_S;
//  - remaining non-constant terms: variables are inserted into d_congruent
//    (see the comment in that branch).
// String-like classes are also appended to d_stringsEqc.
// NOTE(review): this view of the function is missing lines (for example the
// declarations of eqc, n, k, c, s, t, cval, ns and var), so the exact nesting
// of the branches below should be confirmed against the full file.
43 void BaseSolver::checkInit()
50 Trace("strings-base") << "BaseSolver::checkInit" << std::endl
;
51 // count of congruent, non-congruent per operator (independent of type),
53 std::map
<Kind
, std::pair
<uint32_t, uint32_t>> congruentCount
;
// walk every equivalence class known to the equality engine
54 eq::EqualityEngine
* ee
= d_state
.getEqualityEngine();
55 eq::EqClassesIterator eqcs_i
= eq::EqClassesIterator(ee
);
56 while (!eqcs_i
.isFinished())
59 TypeNode tn
= eqc
.getType();
63 // get the term index for type tn
64 std::map
<Kind
, TermIndex
>& tti
= d_termIndex
[tn
];
// string-like classes are recorded in d_stringsEqc, and emps becomes the
// empty word of their type
65 if (tn
.isStringLike())
67 d_stringsEqc
.push_back(eqc
);
68 emps
= Word::mkEmptyWord(tn
);
// iterate over the terms of this equivalence class
71 eq::EqClassIterator eqc_i
= eq::EqClassIterator(eqc
, ee
);
72 while (!eqc_i
.isFinished())
76 Trace("strings-base") << "initialize term: " << n
<< std::endl
;
77 // process constant-like terms
78 if (utils::isConstantLike(n
))
// content recorded so far for this class (null if none yet, see the
// isNull() check under "update best content")
80 Node prev
= d_eqcInfo
[eqc
].d_bestContent
;
83 // we have either (seq.unit x) = C, or (seq.unit x) = (seq.unit y)
84 // where C is a sequence constant.
86 prev
.isConst() ? prev
: (n
.isConst() ? n
: Node::null());
// the explanation for the inferences below is the equality prev = n
87 std::vector
<Node
> exp
;
88 exp
.push_back(prev
.eqNode(n
));
92 // injectivity of seq.unit
98 // should not have two constants in the same equivalence class
99 Assert(cval
.getType().isSequence());
100 std::vector
<Node
> cchars
= Word::getChars(cval
);
101 if (cchars
.size() == 1)
// oval is whichever of prev / n is not the constant; it must be a seq.unit
103 Node oval
= prev
.isConst() ? n
: prev
;
104 Assert(oval
.getKind() == SEQ_UNIT
);
106 t
= cchars
[0].getConst
<Sequence
>().getVec()[0];
107 // oval is congruent (ignored) in this context
108 d_congruent
.insert(oval
);
112 // (seq.unit x) = C => false if |C| != 1.
114 exp
, d_false
, InferenceId::STRINGS_UNIT_CONST_CONFLICT
);
118 if (!d_state
.areEqual(s
, t
))
120 // (seq.unit x) = (seq.unit y) => x=y, or
121 // (seq.unit x) = (seq.unit c) => x=c
122 Assert(s
.getType() == t
.getType());
123 d_im
.sendInference(exp
, s
.eqNode(t
), InferenceId::STRINGS_UNIT_INJ
);
126 // update best content
127 if (prev
.isNull() || n
.isConst())
129 d_eqcInfo
[eqc
].d_bestContent
= n
;
130 d_eqcInfo
[eqc
].d_bestScore
= 0;
131 d_eqcInfo
[eqc
].d_base
= n
;
132 d_eqcInfo
[eqc
].d_exp
= Node::null();
// terms with children are added to the term index for their kind, unless
// they were already marked congruent
140 else if (n
.getNumChildren() > 0)
144 if (d_congruent
.find(n
) == d_congruent
.end())
147 Node nc
= tti
[k
].add(n
, 0, d_state
, emps
, c
);
150 Trace("strings-base-debug")
151 << "...found congruent term " << nc
<< std::endl
;
152 // check if we have inferred a new equality by removal of empty
154 if (k
== STRING_CONCAT
&& !d_state
.areEqual(nc
, n
))
156 std::vector
<Node
> exp
;
157 // current child position in nc (count[0]) and in n (count[1])
158 size_t count
[2] = {0, 0};
159 while (count
[0] < nc
.getNumChildren()
160 || count
[1] < n
.getNumChildren())
162 // explain empty prefixes
163 for (unsigned t
= 0; t
< 2; t
++)
165 Node nn
= t
== 0 ? nc
: n
;
166 while (count
[t
] < nn
.getNumChildren()
167 && (nn
[count
[t
]] == emps
168 || d_state
.areEqual(nn
[count
[t
]], emps
)))
// a child that is syntactically emps needs no explanation
170 if (nn
[count
[t
]] != emps
)
172 exp
.push_back(nn
[count
[t
]].eqNode(emps
));
177 Trace("strings-base-debug")
178 << " counts = " << count
[0] << ", " << count
[1]
180 // explain equal components
181 if (count
[0] < nc
.getNumChildren())
183 Assert(count
[1] < n
.getNumChildren());
184 if (nc
[count
[0]] != n
[count
[1]])
186 exp
.push_back(nc
[count
[0]].eqNode(n
[count
[1]]));
192 // infer the equality
193 d_im
.sendInference(exp
, n
.eqNode(nc
), InferenceId::STRINGS_I_NORM
);
197 // We cannot mark one of the terms as reduced here (via
198 // ExtTheory::markCongruent) since extended function terms
199 // rely on reductions to other extended function terms. We
200 // may have a pair of extended function terms f(a)=f(b) where
201 // the reduction of argument a depends on the term b.
202 // Thus, marking f(b) as reduced by virtue of the fact we
203 // have f(a) is incorrect, since then we are effectively
204 // assuming that the reduction of f(a) depends on itself.
206 // this node is congruent to another one, we can ignore it
207 Trace("strings-base-debug")
208 << " congruent term : " << n
<< " (via " << nc
<< ")"
210 d_congruent
.insert(n
);
211 congruentCount
[k
].first
++;
// a concatenation with a single entry in c
// NOTE(review): c is filled by tti[k].add on lines not visible here;
// presumably it holds the non-empty children of n. Confirm.
213 else if (k
== STRING_CONCAT
&& c
.size() == 1)
215 Trace("strings-base-debug")
216 << " congruent term by singular : " << n
<< " " << c
[0]
219 if (!d_state
.areEqual(c
[0], n
))
222 std::vector
<Node
> exp
;
223 // explain empty components
224 bool foundNEmpty
= false;
225 for (const Node
& nnc
: n
)
227 if (d_state
.areEqual(nnc
, emps
))
231 exp
.push_back(nnc
.eqNode(emps
));
236 Assert(!foundNEmpty
);
241 AlwaysAssert(foundNEmpty
);
242 // infer the equality
243 d_im
.sendInference(exp
, n
.eqNode(ns
), InferenceId::STRINGS_I_NORM_S
);
245 d_congruent
.insert(n
);
246 congruentCount
[k
].first
++;
// counted as non-congruent (second component of the pair)
250 congruentCount
[k
].second
++;
// n was already in d_congruent: counted as congruent
255 congruentCount
[k
].first
++;
259 else if (!n
.isConst())
261 if (d_congruent
.find(n
) == d_congruent
.end())
263 // We mark all but the oldest variable in the equivalence class as
271 Trace("strings-base-debug")
272 << " congruent variable : " << var
<< std::endl
;
273 d_congruent
.insert(var
);
278 Trace("strings-base-debug")
279 << " congruent variable : " << n
<< std::endl
;
280 d_congruent
.insert(n
);
// when tracing is on, report per kind: non-congruent terms / all terms
289 if (Trace
.isOn("strings-base"))
291 for (const std::pair
<const Kind
, std::pair
<uint32_t, uint32_t>>& cc
:
294 Trace("strings-base")
295 << " Terms[" << cc
.first
<< "] = " << cc
.second
.second
<< "/"
296 << (cc
.second
.first
+ cc
.second
.second
) << std::endl
;
299 Trace("strings-base") << "BaseSolver::checkInit finished" << std::endl
;
// Driver for constant inference over the STRING_CONCAT term indices.
//
// Phase 1: repeatedly calls the recursive overload on the STRING_CONCAT index
// of each type, passing `true` as its third argument, and repeats for as long
// as d_im has not processed anything and d_eqcInfo grew during the round.
// Phase 2: if d_im still has not processed anything, runs the overload once
// more per type with `false` as the third argument, to set the "most content"
// terms.
// NOTE(review): the container iterated by the two range-for loops is on lines
// not visible here; its element type matches d_termIndex. Confirm.
302 void BaseSolver::checkConstantEquivalenceClasses()
// scratch vector shared by all recursive calls below
306 std::vector
<Node
> vecc
;
310 Trace("strings-base-debug")
311 << "Check constant equivalence classes..." << std::endl
;
// remember the size so the loop condition can detect newly added entries
312 prevSize
= d_eqcInfo
.size();
313 for (std::pair
<const TypeNode
, std::map
<Kind
, TermIndex
>>& tindex
:
316 checkConstantEquivalenceClasses(
317 &tindex
.second
[STRING_CONCAT
], vecc
, true);
319 } while (!d_im
.hasProcessed() && d_eqcInfo
.size() > prevSize
);
321 if (!d_im
.hasProcessed())
323 // now, go back and set "most content" terms
325 for (std::pair
<const TypeNode
, std::map
<Kind
, TermIndex
>>& tindex
:
328 checkConstantEquivalenceClasses(
329 &tindex
.second
[STRING_CONCAT
], vecc
, false);
// Recursive worker over one node `ti` of a term index.
//
// `vecc` is extended by one entry per recursion level (see the loop over
// ti->d_children at the end): either the constant d_bestContent of the child's
// key, or a null node when no constant is known and ensureConst is false.
// For the term n handled at this node, the visible code:
//  - builds c from vecc with utils::mkNConcat;
//  - collects an explanation `exp` and a component vector `vecnc`;
//  - then either stores a better non-constant best content for n's class,
//    sends STRINGS_I_CONST_MERGE when c is already a term of the state,
//    records c as the class constant, or sends STRINGS_I_CONST_CONFLICT when
//    the class already has a different constant.
// NOTE(review): the remaining parameters (used below as `ensureConst` and
// `isConst`) and the declarations of n, c, count, countc and emps are on lines
// not visible in this view; confirm the details against the full file.
334 void BaseSolver::checkConstantEquivalenceClasses(TermIndex
* ti
,
335 std::vector
<Node
>& vecc
,
342 // construct the constant if applicable
346 c
= utils::mkNConcat(vecc
, n
.getType());
// proceed unless vecc is all-constant and n is already known equal to c
348 if (!isConst
|| !d_state
.areEqual(n
, c
))
350 if (Trace
.isOn("strings-debug"))
352 Trace("strings-debug")
353 << "Constant eqc : " << c
<< " for " << n
<< std::endl
;
354 Trace("strings-debug") << " ";
355 for (const Node
& v
: vecc
)
357 Trace("strings-debug") << v
<< " ";
359 Trace("strings-debug") << std::endl
;
363 std::vector
<Node
> exp
;
364 // non-constant vector
365 std::vector
<Node
> vecnc
;
// accumulated length of the constant components pushed to vecnc
366 size_t contentSize
= 0;
367 while (count
< n
.getNumChildren())
369 // Add explanations for the empty children
371 while (count
< n
.getNumChildren()
372 && d_state
.isEqualEmptyWord(n
[count
], emps
))
374 d_im
.addToExplanation(n
[count
], emps
, exp
);
377 if (count
< n
.getNumChildren())
379 if (vecc
[countc
].isNull())
382 // no constant for this component, leave it as is
383 vecnc
.push_back(n
[count
]);
// a constant is known for this component: use it in place of the child
390 vecnc
.push_back(vecc
[countc
]);
391 Assert(vecc
[countc
].isConst());
392 contentSize
+= Word::getLength(vecc
[countc
]);
394 Trace("strings-debug") << "...explain " << n
[count
] << " "
395 << vecc
[countc
] << std::endl
;
396 if (!d_state
.areEqual(n
[count
], vecc
[countc
]))
// the child is not directly equal to the constant: explain it through the
// d_exp / d_base stored for the child's representative
398 Node nrr
= d_state
.getRepresentative(n
[count
]);
399 Assert(!d_eqcInfo
[nrr
].d_bestContent
.isNull()
400 && d_eqcInfo
[nrr
].d_bestContent
.isConst());
401 // must flatten to avoid nested AND in explanations
402 utils::flattenOp(AND
, d_eqcInfo
[nrr
].d_exp
, exp
);
403 // now explain equality to base
404 d_im
.addToExplanation(n
[count
], d_eqcInfo
[nrr
].d_base
, exp
);
408 d_im
.addToExplanation(n
[count
], vecc
[countc
], exp
);
415 // exp contains an explanation of n==c
416 Assert(!isConst
|| countc
== vecc
.size());
419 // no use storing something with no content
422 Node nr
= d_state
.getRepresentative(n
);
423 BaseEqcInfo
& bei
= d_eqcInfo
[nr
];
// only replace a non-constant best content, and only when none is stored
// yet or this one has a strictly larger score
424 if (!bei
.d_bestContent
.isConst()
425 && (bei
.d_bestContent
.isNull() || contentSize
> bei
.d_bestScore
))
427 // The equivalence class is not entailed to be equal to a constant
428 // and we found a better concatenation
429 Node nct
= utils::mkNConcat(vecnc
, n
.getType());
430 Assert(!nct
.isConst());
431 bei
.d_bestContent
= nct
;
432 bei
.d_bestScore
= contentSize
;
436 bei
.d_exp
= utils::mkAnd(exp
);
438 Trace("strings-debug")
439 << "Set eqc best content " << n
<< " to " << nct
440 << ", explanation = " << bei
.d_exp
<< std::endl
;
// c already occurs as a term in the state: infer n = c directly
444 else if (d_state
.hasTerm(c
))
446 d_im
.sendInference(exp
, n
.eqNode(c
), InferenceId::STRINGS_I_CONST_MERGE
);
// otherwise record c for n's class, or detect a clash with a stored constant
449 else if (!d_im
.hasProcessed())
451 Node nr
= d_state
.getRepresentative(n
);
452 BaseEqcInfo
& bei
= d_eqcInfo
[nr
];
453 if (!bei
.d_bestContent
.isConst())
455 bei
.d_bestContent
= c
;
457 bei
.d_exp
= utils::mkAnd(exp
);
458 Trace("strings-debug")
459 << "Set eqc const " << n
<< " to " << c
460 << ", explanation = " << bei
.d_exp
<< std::endl
;
462 else if (c
!= bei
.d_bestContent
)
465 Trace("strings-debug")
466 << "Conflict, other constant was " << bei
.d_bestContent
467 << ", this constant was " << c
<< std::endl
;
468 if (bei
.d_exp
.isNull())
470 // n==c ^ n == c' => false
471 d_im
.addToExplanation(n
, bei
.d_bestContent
, exp
);
475 // n==c ^ n == d_base == c' => false
476 exp
.push_back(bei
.d_exp
);
477 d_im
.addToExplanation(n
, bei
.d_base
, exp
);
479 d_im
.sendInference(exp
, d_false
, InferenceId::STRINGS_I_CONST_CONFLICT
);
484 Trace("strings-debug") << "Duplicate constant." << std::endl
;
// recurse into each child of this index node
489 for (std::pair
<const TNode
, TermIndex
>& p
: ti
->d_children
)
491 std::map
<Node
, BaseEqcInfo
>::const_iterator it
= d_eqcInfo
.find(p
.first
);
// the child's key has a known constant: extend vecc with it
492 if (it
!= d_eqcInfo
.end() && it
->second
.d_bestContent
.isConst())
494 vecc
.push_back(it
->second
.d_bestContent
);
495 checkConstantEquivalenceClasses(&p
.second
, vecc
, ensureConst
, isConst
);
498 else if (!ensureConst
)
500 // can still proceed, with null
501 vecc
.push_back(Node::null());
502 checkConstantEquivalenceClasses(&p
.second
, vecc
, ensureConst
, false);
505 if (d_im
.hasProcessed())
// Entry point of the cardinality check: partitions d_stringsEqc by length via
// d_state.separateByLength, then calls checkCardinalityType once per type with
// that type's collections and the matching entry of lts.
512 void BaseSolver::checkCardinality()
514 // This will create a partition of eqc, where each collection has lengths
515 // that are pairwise propagated to be equal. We do not require disequalities
516 // between the lengths of each collection, since we split on disequalities
517 // between lengths of string terms that are disequal (DEQ-LENGTH-SP).
518 std::map
<TypeNode
, std::vector
<std::vector
<Node
> > > cols
;
519 std::map
<TypeNode
, std::vector
<Node
> > lts
;
520 d_state
.separateByLength(d_stringsEqc
, cols
, lts
);
// c.first is the type, c.second its collections of equal-length classes
521 for (std::pair
<const TypeNode
, std::vector
<std::vector
<Node
> > >& c
: cols
)
523 checkCardinalityType(c
.first
, c
.second
, lts
[c
.first
]);
// Cardinality check for a single type.
//
// tn   : the string or sequence type being checked.
// cols : collections of terms of type tn, as produced by separateByLength.
// lts  : passed alongside cols. NOTE(review): presumably lts[i] is the length
//        term `lr` used below for cols[i]; the defining line is not visible in
//        this view. Confirm.
//
// For each collection with more than one element, the visible code computes a
// required length card_need from the collection size and typeCardSize, and
// checks whether lr is already known to be at least card_need. If not, it
// first tries to split on a pair of members not known to be disequal
// (STRINGS_CARD_SP). Otherwise it sends a STRINGS_CARDINALITY lemma whose
// conclusion is len(cols[i][0]) >= card_need, unless that rewrites to true.
527 void BaseSolver::checkCardinalityType(TypeNode tn
,
528 std::vector
<std::vector
<Node
> >& cols
,
529 std::vector
<Node
>& lts
)
531 Trace("strings-card") << "Check cardinality (type " << tn
<< ")..."
533 NodeManager
* nm
= NodeManager::currentNM();
// number of distinct characters/elements assumed available for type tn
534 uint32_t typeCardSize
;
535 if (tn
.isString()) // string-only
537 typeCardSize
= d_cardSize
;
541 Assert(tn
.isSequence());
542 TypeNode etn
= tn
.getSequenceElementType();
543 if (!d_state
.isFiniteType(etn
))
545 // infinite cardinality, we are fine
548 // TODO (cvc4-projects #23): how to handle sequence for finite types?
551 // for each collection
552 for (unsigned i
= 0, csize
= cols
.size(); i
< csize
; ++i
)
555 Trace("strings-card") << "Number of strings with length equal to " << lr
556 << " is " << cols
[i
].size() << std::endl
;
557 if (cols
[i
].size() <= 1)
559 // no restriction on sets in the partition of size 1
// compute the length needed to fit cols[i].size() pairwise distinct words
// by repeatedly dividing by typeCardSize
// NOTE(review): the statement that increments card_need is on a line not
// visible here; presumably once per division. Confirm.
563 unsigned card_need
= 1;
564 double curr
= static_cast<double>(cols
[i
].size());
565 while (curr
> typeCardSize
)
567 curr
= curr
/ static_cast<double>(typeCardSize
);
570 Trace("strings-card")
571 << "Need length " << card_need
572 << " for this number of strings (where alphabet size is "
573 << typeCardSize
<< ") given type " << tn
<< "." << std::endl
;
574 // check if we need to split
575 bool needsSplit
= true;
578 // if constant, compare
579 Node cmp
= nm
->mkNode(GEQ
, lr
, nm
->mkConst(Rational(card_need
)));
580 cmp
= Rewriter::rewrite(cmp
);
581 needsSplit
= !cmp
.getConst
<bool>();
585 // find the minimum constant that lr is not known to be disequal from, or
586 // otherwise stop if we increment such that cardinality does not apply.
587 // We always start with r=1 since by the invariants of our term registry,
588 // a term is either equal to the empty string, or has length >= 1.
591 while (r
< card_need
&& success
)
593 Node rr
= nm
->mkConst(Rational(r
));
594 if (d_state
.areDisequal(rr
, lr
))
605 Trace("strings-card")
606 << "Symbolic length " << lr
<< " must be at least " << r
607 << " due to constant disequalities." << std::endl
;
// still need to act only if the derived lower bound r is below card_need
609 needsSplit
= r
< card_need
;
614 // don't need to split
617 // first, try to split to merge equivalence classes
618 for (std::vector
<Node
>::iterator itr1
= cols
[i
].begin();
619 itr1
!= cols
[i
].end();
622 for (std::vector
<Node
>::iterator itr2
= itr1
+ 1; itr2
!= cols
[i
].end();
625 if (!d_state
.areDisequal(*itr1
, *itr2
))
628 if (d_im
.sendSplit(*itr1
, *itr2
, InferenceId::STRINGS_CARD_SP
))
635 // otherwise, we need a length constraint
636 uint32_t int_k
= static_cast<uint32_t>(card_need
);
637 EqcInfo
* ei
= d_state
.getOrMakeEqcInfo(lr
, true);
638 Trace("strings-card") << "Previous cardinality used for " << lr
<< " is "
639 << ((int)ei
->d_cardinalityLemK
.get() - 1)
// d_cardinalityLemK stores int_k + 1 (see the set() call below), so this
// test passes only when int_k is at least the stored value
641 if (int_k
+ 1 > ei
->d_cardinalityLemK
.get())
643 Node k_node
= nm
->mkConst(Rational(int_k
));
644 // add cardinality lemma
// explanation: the members of cols[i] are pairwise distinct, plus the
// length equalities len = lr pushed in the loop below
645 Node dist
= nm
->mkNode(DISTINCT
, cols
[i
]);
646 std::vector
<Node
> expn
;
647 expn
.push_back(dist
);
648 for (std::vector
<Node
>::iterator itr1
= cols
[i
].begin();
649 itr1
!= cols
[i
].end();
652 Node len
= nm
->mkNode(STRING_LENGTH
, *itr1
);
655 Node len_eq_lr
= len
.eqNode(lr
);
656 expn
.push_back(len_eq_lr
);
// conclusion: the length of the first member is at least k_node
659 Node len
= nm
->mkNode(STRING_LENGTH
, cols
[i
][0]);
660 Node cons
= nm
->mkNode(GEQ
, len
, k_node
);
661 cons
= Rewriter::rewrite(cons
);
662 ei
->d_cardinalityLemK
.set(int_k
+ 1);
// skip the lemma when the conclusion rewrites to the constant true
663 if (!cons
.isConst() || !cons
.getConst
<bool>())
666 expn
, expn
, cons
, InferenceId::STRINGS_CARDINALITY
, false, true);
671 Trace("strings-card") << "...end check cardinality" << std::endl
;
// Returns true iff n is currently contained in d_congruent.
674 bool BaseSolver::isCongruent(Node n
)
676 return d_congruent
.find(n
) != d_congruent
.end();
// Looks up eqc in d_eqcInfo and returns its d_bestContent when an entry exists
// and that content is a constant.
// NOTE(review): the fall-through return for the other cases is on a line not
// visible in this view; presumably a null node. Confirm.
679 Node
BaseSolver::getConstantEqc(Node eqc
)
681 std::map
<Node
, BaseEqcInfo
>::const_iterator it
= d_eqcInfo
.find(eqc
);
682 if (it
!= d_eqcInfo
.end() && it
->second
.d_bestContent
.isConst())
684 return it
->second
.d_bestContent
;
// Returns the best content recorded for eqc and appends its explanation to
// exp.
//
// When d_eqcInfo has an entry for eqc, the visible code:
//  - checks whether d_bestContent is a constant (the body of the non-constant
//    branch is not visible here; presumably an early return, confirm);
//  - flattens the conjuncts of d_exp into exp when d_exp is non-null;
//  - adds an explanation of n = d_base to exp via d_im.addToExplanation when
//    d_base is non-null;
//  - returns d_bestContent.
// NOTE(review): the return for the "no entry" case is not visible here.
689 Node
BaseSolver::explainConstantEqc(Node n
, Node eqc
, std::vector
<Node
>& exp
)
691 std::map
<Node
, BaseEqcInfo
>::const_iterator it
= d_eqcInfo
.find(eqc
);
692 if (it
!= d_eqcInfo
.end())
// NOTE(review): this is a second lookup of the entry `it` already points to
694 BaseEqcInfo
& bei
= d_eqcInfo
[eqc
];
695 if (!bei
.d_bestContent
.isConst())
699 if (!bei
.d_exp
.isNull())
701 utils::flattenOp(AND
, bei
.d_exp
, exp
);
703 if (!bei
.d_base
.isNull())
705 d_im
.addToExplanation(n
, bei
.d_base
, exp
);
707 return bei
.d_bestContent
;
// Like explainConstantEqc, but only asserts that the stored d_bestContent is
// non-null; the visible code has no isConst() check.
//
// When d_eqcInfo has an entry for eqc: flattens d_exp into exp (if non-null),
// adds an explanation of n = d_base to exp (if d_base is non-null), and
// returns d_bestContent.
// NOTE(review): the return for the "no entry" case is not visible here.
712 Node
BaseSolver::explainBestContentEqc(Node n
, Node eqc
, std::vector
<Node
>& exp
)
714 std::map
<Node
, BaseEqcInfo
>::const_iterator it
= d_eqcInfo
.find(eqc
);
715 if (it
!= d_eqcInfo
.end())
// NOTE(review): this is a second lookup of the entry `it` already points to
717 BaseEqcInfo
& bei
= d_eqcInfo
[eqc
];
718 Assert(!bei
.d_bestContent
.isNull());
719 if (!bei
.d_exp
.isNull())
721 utils::flattenOp(AND
, bei
.d_exp
, exp
);
723 if (!bei
.d_base
.isNull())
725 d_im
.addToExplanation(n
, bei
.d_base
, exp
);
727 return bei
.d_bestContent
;
// Const accessor returning a reference to a vector of nodes.
// NOTE(review): the body is not visible in this view; presumably it returns
// d_stringsEqc (filled in checkInit). Confirm.
733 const std::vector
<Node
>& BaseSolver::getStringEqc() const
// Recursively inserts term n into this term index, one child per level.
//
// n : the term being indexed.
// s : solver state, used to look up the representative of each child.
// c : output vector of nodes.
//     NOTE(review): the statement that fills c is not visible here;
//     presumably it collects the children that were not skipped. Confirm.
// The parameters `index` (current child position) and `er` (compared against
// the child's representative) are declared on lines not visible in this view.
//
// At each level the representative nir of n[index] is computed. If nir == er
// and n is a STRING_CONCAT, that child is skipped and the recursion stays at
// this index node. Otherwise the recursion descends into d_children[nir].
// NOTE(review): the base case (index == n.getNumChildren()) has its body on
// lines not visible here; callers treat the returned node as the term found
// in the index for n.
738 Node
BaseSolver::TermIndex::add(TNode n
,
740 const SolverState
& s
,
742 std::vector
<Node
>& c
)
744 if (index
== n
.getNumChildren())
752 Assert(index
< n
.getNumChildren());
753 TNode nir
= s
.getRepresentative(n
[index
]);
754 // if it is empty, and doing CONCAT, ignore
755 if (nir
== er
&& n
.getKind() == STRING_CONCAT
)
757 return add(n
, index
+ 1, s
, er
, c
);
// otherwise descend into (or create) the child keyed by the representative
760 return d_children
[nir
].add(n
, index
+ 1, s
, er
, c
);
763 } // namespace strings
764 } // namespace theory