bv: Refactor getEqualityStatus and use for both bitblasting solvers. (#6933)
[cvc5.git] / src / theory / bv / theory_bv.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Mathias Preiner, Andrew Reynolds, Haniel Barbosa
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Theory of bit-vectors.
14 */
15
16 #include "theory/bv/theory_bv.h"
17
18 #include "options/bv_options.h"
19 #include "options/smt_options.h"
20 #include "proof/proof_checker.h"
21 #include "smt/smt_statistics_registry.h"
22 #include "theory/bv/bv_solver_bitblast.h"
23 #include "theory/bv/bv_solver_bitblast_internal.h"
24 #include "theory/bv/bv_solver_layered.h"
25 #include "theory/bv/theory_bv_utils.h"
26 #include "theory/ee_setup_info.h"
27 #include "theory/trust_substitutions.h"
28
29 namespace cvc5 {
30 namespace theory {
31 namespace bv {
32
33 TheoryBV::TheoryBV(context::Context* c,
34 context::UserContext* u,
35 OutputChannel& out,
36 Valuation valuation,
37 const LogicInfo& logicInfo,
38 ProofNodeManager* pnm,
39 std::string name)
40 : Theory(THEORY_BV, c, u, out, valuation, logicInfo, pnm, name),
41 d_internal(nullptr),
42 d_rewriter(),
43 d_state(c, u, valuation),
44 d_im(*this, d_state, nullptr, "theory::bv::"),
45 d_notify(d_im),
46 d_invalidateModelCache(c, true),
47 d_stats("theory::bv::")
48 {
49 switch (options::bvSolver())
50 {
51 case options::BVSolver::BITBLAST:
52 d_internal.reset(new BVSolverBitblast(&d_state, d_im, pnm));
53 break;
54
55 case options::BVSolver::LAYERED:
56 d_internal.reset(new BVSolverLayered(*this, c, u, pnm, name));
57 break;
58
59 default:
60 AlwaysAssert(options::bvSolver() == options::BVSolver::BITBLAST_INTERNAL);
61 d_internal.reset(new BVSolverBitblastInternal(&d_state, d_im, pnm));
62 }
63 d_theoryState = &d_state;
64 d_inferManager = &d_im;
65 }
66
67 TheoryBV::~TheoryBV() {}
68
69 TheoryRewriter* TheoryBV::getTheoryRewriter() { return &d_rewriter; }
70
71 ProofRuleChecker* TheoryBV::getProofChecker()
72 {
73 if (options::bvSolver() == options::BVSolver::BITBLAST_INTERNAL)
74 {
75 return static_cast<BVSolverBitblastInternal*>(d_internal.get())
76 ->getProofChecker();
77 }
78 return nullptr;
79 }
80
81 bool TheoryBV::needsEqualityEngine(EeSetupInfo& esi)
82 {
83 bool need_ee = d_internal->needsEqualityEngine(esi);
84
85 /* Set up default notify class for equality engine. */
86 if (need_ee && esi.d_notify == nullptr)
87 {
88 esi.d_notify = &d_notify;
89 esi.d_name = "theory::bv::ee";
90 }
91
92 return need_ee;
93 }
94
95 void TheoryBV::finishInit()
96 {
97 // these kinds are semi-evaluated in getModelValue (applications of this
98 // kind are treated as variables)
99 getValuation().setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UDIV);
100 getValuation().setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UREM);
101 d_internal->finishInit();
102
103 eq::EqualityEngine* ee = getEqualityEngine();
104 if (ee)
105 {
106 // The kinds we are treating as function application in congruence
107 ee->addFunctionKind(kind::BITVECTOR_CONCAT, true);
108 // ee->addFunctionKind(kind::BITVECTOR_AND);
109 // ee->addFunctionKind(kind::BITVECTOR_OR);
110 // ee->addFunctionKind(kind::BITVECTOR_XOR);
111 // ee->addFunctionKind(kind::BITVECTOR_NOT);
112 // ee->addFunctionKind(kind::BITVECTOR_NAND);
113 // ee->addFunctionKind(kind::BITVECTOR_NOR);
114 // ee->addFunctionKind(kind::BITVECTOR_XNOR);
115 // ee->addFunctionKind(kind::BITVECTOR_COMP);
116 ee->addFunctionKind(kind::BITVECTOR_MULT, true);
117 ee->addFunctionKind(kind::BITVECTOR_ADD, true);
118 ee->addFunctionKind(kind::BITVECTOR_EXTRACT, true);
119 // ee->addFunctionKind(kind::BITVECTOR_SUB);
120 // ee->addFunctionKind(kind::BITVECTOR_NEG);
121 // ee->addFunctionKind(kind::BITVECTOR_UDIV);
122 // ee->addFunctionKind(kind::BITVECTOR_UREM);
123 // ee->addFunctionKind(kind::BITVECTOR_SDIV);
124 // ee->addFunctionKind(kind::BITVECTOR_SREM);
125 // ee->addFunctionKind(kind::BITVECTOR_SMOD);
126 // ee->addFunctionKind(kind::BITVECTOR_SHL);
127 // ee->addFunctionKind(kind::BITVECTOR_LSHR);
128 // ee->addFunctionKind(kind::BITVECTOR_ASHR);
129 // ee->addFunctionKind(kind::BITVECTOR_ULT);
130 // ee->addFunctionKind(kind::BITVECTOR_ULE);
131 // ee->addFunctionKind(kind::BITVECTOR_UGT);
132 // ee->addFunctionKind(kind::BITVECTOR_UGE);
133 // ee->addFunctionKind(kind::BITVECTOR_SLT);
134 // ee->addFunctionKind(kind::BITVECTOR_SLE);
135 // ee->addFunctionKind(kind::BITVECTOR_SGT);
136 // ee->addFunctionKind(kind::BITVECTOR_SGE);
137 ee->addFunctionKind(kind::BITVECTOR_TO_NAT);
138 ee->addFunctionKind(kind::INT_TO_BITVECTOR);
139 }
140 }
141
142 void TheoryBV::preRegisterTerm(TNode node)
143 {
144 d_internal->preRegisterTerm(node);
145
146 eq::EqualityEngine* ee = getEqualityEngine();
147 if (ee)
148 {
149 if (node.getKind() == kind::EQUAL)
150 {
151 ee->addTriggerPredicate(node);
152 }
153 else
154 {
155 ee->addTerm(node);
156 }
157 }
158 }
159
160 bool TheoryBV::preCheck(Effort e) { return d_internal->preCheck(e); }
161
162 void TheoryBV::postCheck(Effort e)
163 {
164 d_invalidateModelCache = true;
165 d_internal->postCheck(e);
166 }
167
168 bool TheoryBV::preNotifyFact(
169 TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
170 {
171 return d_internal->preNotifyFact(atom, pol, fact, isPrereg, isInternal);
172 }
173
174 void TheoryBV::notifyFact(TNode atom, bool pol, TNode fact, bool isInternal)
175 {
176 d_internal->notifyFact(atom, pol, fact, isInternal);
177 }
178
179 bool TheoryBV::needsCheckLastEffort()
180 {
181 return d_internal->needsCheckLastEffort();
182 }
183
184 void TheoryBV::computeRelevantTerms(std::set<Node>& termSet)
185 {
186 return d_internal->computeRelevantTerms(termSet);
187 }
188
189 bool TheoryBV::collectModelValues(TheoryModel* m, const std::set<Node>& termSet)
190 {
191 return d_internal->collectModelValues(m, termSet);
192 }
193
194 void TheoryBV::propagate(Effort e) { return d_internal->propagate(e); }
195
196 Theory::PPAssertStatus TheoryBV::ppAssert(
197 TrustNode tin, TrustSubstitutionMap& outSubstitutions)
198 {
199 TNode in = tin.getNode();
200 Kind k = in.getKind();
201 if (k == kind::EQUAL)
202 {
203 // Substitute variables
204 if (in[0].isVar() && isLegalElimination(in[0], in[1]))
205 {
206 ++d_stats.d_solveSubstitutions;
207 outSubstitutions.addSubstitutionSolved(in[0], in[1], tin);
208 return Theory::PP_ASSERT_STATUS_SOLVED;
209 }
210 if (in[1].isVar() && isLegalElimination(in[1], in[0]))
211 {
212 ++d_stats.d_solveSubstitutions;
213 outSubstitutions.addSubstitutionSolved(in[1], in[0], tin);
214 return Theory::PP_ASSERT_STATUS_SOLVED;
215 }
216 /**
217 * Eliminate extract over bit-vector variables.
218 *
219 * Given x[h:l] = c, where c is a constant and x is a variable.
220 *
221 * We rewrite to:
222 *
223 * x = sk1::c if l == 0, where bw(sk1) = bw(x)-1-h
224 * x = c::sk2 if h == bw(x)-1, where bw(sk2) = l
225 * x = sk1::c::sk2 otherwise, where bw(sk1) = bw(x)-1-h and bw(sk2) = l
226 */
227 Node node = Rewriter::rewrite(in);
228 if ((node[0].getKind() == kind::BITVECTOR_EXTRACT && node[1].isConst())
229 || (node[1].getKind() == kind::BITVECTOR_EXTRACT
230 && node[0].isConst()))
231 {
232 Node extract = node[0].isConst() ? node[1] : node[0];
233 if (extract[0].isVar())
234 {
235 Node c = node[0].isConst() ? node[0] : node[1];
236
237 uint32_t high = utils::getExtractHigh(extract);
238 uint32_t low = utils::getExtractLow(extract);
239 uint32_t var_bw = utils::getSize(extract[0]);
240 std::vector<Node> children;
241
242 // create sk1 with size bw(x)-1-h
243 if (low == 0 || high != var_bw - 1)
244 {
245 Assert(high != var_bw - 1);
246 uint32_t skolem_size = var_bw - high - 1;
247 Node skolem = utils::mkVar(skolem_size);
248 children.push_back(skolem);
249 }
250
251 children.push_back(c);
252
253 // create sk2 with size l
254 if (high == var_bw - 1 || low != 0)
255 {
256 Assert(low != 0);
257 uint32_t skolem_size = low;
258 Node skolem = utils::mkVar(skolem_size);
259 children.push_back(skolem);
260 }
261
262 Node concat = utils::mkConcat(children);
263 Assert(utils::getSize(concat) == utils::getSize(extract[0]));
264 if (isLegalElimination(extract[0], concat))
265 {
266 outSubstitutions.addSubstitutionSolved(extract[0], concat, tin);
267 return Theory::PP_ASSERT_STATUS_SOLVED;
268 }
269 }
270 }
271 }
272 return Theory::PP_ASSERT_STATUS_UNSOLVED;
273 }
274
275 TrustNode TheoryBV::ppRewrite(TNode t, std::vector<SkolemLemma>& lems)
276 {
277 // first, see if we need to expand definitions
278 TrustNode texp = d_rewriter.expandDefinition(t);
279 if (!texp.isNull())
280 {
281 return texp;
282 }
283 return d_internal->ppRewrite(t);
284 }
285
286 void TheoryBV::presolve() { d_internal->presolve(); }
287
288 EqualityStatus TheoryBV::getEqualityStatus(TNode a, TNode b)
289 {
290 EqualityStatus status = d_internal->getEqualityStatus(a, b);
291
292 if (status == EqualityStatus::EQUALITY_UNKNOWN)
293 {
294 Node value_a = getValue(a);
295 Node value_b = getValue(b);
296
297 if (value_a.isNull() || value_b.isNull())
298 {
299 return status;
300 }
301
302 if (value_a == value_b)
303 {
304 Debug("theory-bv") << EQUALITY_TRUE_IN_MODEL << std::endl;
305 return EQUALITY_TRUE_IN_MODEL;
306 }
307 Debug("theory-bv") << EQUALITY_FALSE_IN_MODEL << std::endl;
308 return EQUALITY_FALSE_IN_MODEL;
309 }
310 return status;
311 }
312
313 TrustNode TheoryBV::explain(TNode node) { return d_internal->explain(node); }
314
315 void TheoryBV::notifySharedTerm(TNode t)
316 {
317 d_internal->notifySharedTerm(t);
318 }
319
320 void TheoryBV::ppStaticLearn(TNode in, NodeBuilder& learned)
321 {
322 d_internal->ppStaticLearn(in, learned);
323 }
324
325 bool TheoryBV::applyAbstraction(const std::vector<Node>& assertions,
326 std::vector<Node>& new_assertions)
327 {
328 return d_internal->applyAbstraction(assertions, new_assertions);
329 }
330
331 Node TheoryBV::getValue(TNode node)
332 {
333 if (d_invalidateModelCache.get())
334 {
335 d_modelCache.clear();
336 }
337 d_invalidateModelCache.set(false);
338
339 std::vector<TNode> visit;
340
341 TNode cur;
342 visit.push_back(node);
343 do
344 {
345 cur = visit.back();
346 visit.pop_back();
347
348 auto it = d_modelCache.find(cur);
349 if (it != d_modelCache.end() && !it->second.isNull())
350 {
351 continue;
352 }
353
354 if (cur.isConst())
355 {
356 d_modelCache[cur] = cur;
357 continue;
358 }
359
360 Node value = d_internal->getValue(cur, false);
361 if (value.isConst())
362 {
363 d_modelCache[cur] = value;
364 continue;
365 }
366
367 if (Theory::isLeafOf(cur, theory::THEORY_BV))
368 {
369 value = d_internal->getValue(cur, true);
370 d_modelCache[cur] = value;
371 continue;
372 }
373
374 if (it == d_modelCache.end())
375 {
376 visit.push_back(cur);
377 d_modelCache.emplace(cur, Node());
378 visit.insert(visit.end(), cur.begin(), cur.end());
379 }
380 else if (it->second.isNull())
381 {
382 NodeBuilder nb(cur.getKind());
383 if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
384 {
385 nb << cur.getOperator();
386 }
387
388 std::unordered_map<Node, Node>::iterator iit;
389 for (const TNode& child : cur)
390 {
391 iit = d_modelCache.find(child);
392 Assert(iit != d_modelCache.end());
393 Assert(iit->second.isConst());
394 nb << iit->second;
395 }
396 it->second = Rewriter::rewrite(nb.constructNode());
397 }
398 } while (!visit.empty());
399
400 auto it = d_modelCache.find(node);
401 Assert(it != d_modelCache.end());
402 return it->second;
403 }
404
405 TheoryBV::Statistics::Statistics(const std::string& name)
406 : d_solveSubstitutions(
407 smtStatisticsRegistry().registerInt(name + "NumSolveSubstitutions"))
408 {
409 }
410
411 } // namespace bv
412 } // namespace theory
413 } // namespace cvc5