FloatingPoint: Separate out symFPU glue code. (#5492)
[cvc5.git] / src / theory / ext_theory.cpp
1 /********************* */
2 /*! \file ext_theory.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Tim King, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Extended theory interface.
13 **
14 ** This implements a generic module, used by theory solvers, for performing
15 ** "context-dependent simplification", as described in Reynolds et al
16 ** "Designing Theory Solvers with Extensions", FroCoS 2017.
17 **/
18
19 #include "theory/ext_theory.h"
20
21 #include "base/check.h"
22 #include "smt/smt_statistics_registry.h"
23 #include "theory/quantifiers_engine.h"
24 #include "theory/substitutions.h"
25
26 using namespace std;
27
28 namespace CVC4 {
29 namespace theory {
30
31 bool ExtTheoryCallback::getCurrentSubstitution(
32 int effort,
33 const std::vector<Node>& vars,
34 std::vector<Node>& subs,
35 std::map<Node, std::vector<Node> >& exp)
36 {
37 return false;
38 }
39 bool ExtTheoryCallback::isExtfReduced(int effort,
40 Node n,
41 Node on,
42 std::vector<Node>& exp)
43 {
44 return n.isConst();
45 }
46 bool ExtTheoryCallback::getReduction(int effort,
47 Node n,
48 Node& nr,
49 bool& isSatDep)
50 {
51 return false;
52 }
53
54 ExtTheory::ExtTheory(ExtTheoryCallback& p,
55 context::Context* c,
56 context::UserContext* u,
57 OutputChannel& out,
58 bool cacheEnabled)
59 : d_parent(p),
60 d_out(out),
61 d_ext_func_terms(c),
62 d_ci_inactive(u),
63 d_has_extf(c),
64 d_lemmas(u),
65 d_pp_lemmas(u),
66 d_cacheEnabled(cacheEnabled)
67 {
68 d_true = NodeManager::currentNM()->mkConst(true);
69 }
70
71 // Gets all leaf terms in n.
72 std::vector<Node> ExtTheory::collectVars(Node n)
73 {
74 std::vector<Node> vars;
75 std::set<Node> visited;
76 std::vector<Node> worklist;
77 worklist.push_back(n);
78 while (!worklist.empty())
79 {
80 Node current = worklist.back();
81 worklist.pop_back();
82 if (current.isConst() || visited.count(current) > 0)
83 {
84 continue;
85 }
86 visited.insert(current);
87 // Treat terms not belonging to this theory as leaf
88 // note : chould include terms not belonging to this theory
89 // (commented below)
90 if (current.getNumChildren() > 0)
91 {
92 worklist.insert(worklist.end(), current.begin(), current.end());
93 }
94 else
95 {
96 vars.push_back(current);
97 }
98 }
99 return vars;
100 }
101
102 Node ExtTheory::getSubstitutedTerm(int effort,
103 Node term,
104 std::vector<Node>& exp,
105 bool useCache)
106 {
107 if (useCache)
108 {
109 Assert(d_gst_cache[effort].find(term) != d_gst_cache[effort].end());
110 exp.insert(exp.end(),
111 d_gst_cache[effort][term].d_exp.begin(),
112 d_gst_cache[effort][term].d_exp.end());
113 return d_gst_cache[effort][term].d_sterm;
114 }
115
116 std::vector<Node> terms;
117 terms.push_back(term);
118 std::vector<Node> sterms;
119 std::vector<std::vector<Node> > exps;
120 getSubstitutedTerms(effort, terms, sterms, exps, useCache);
121 Assert(sterms.size() == 1);
122 Assert(exps.size() == 1);
123 exp.insert(exp.end(), exps[0].begin(), exps[0].end());
124 return sterms[0];
125 }
126
127 // do inferences
128 void ExtTheory::getSubstitutedTerms(int effort,
129 const std::vector<Node>& terms,
130 std::vector<Node>& sterms,
131 std::vector<std::vector<Node> >& exp,
132 bool useCache)
133 {
134 if (useCache)
135 {
136 for (const Node& n : terms)
137 {
138 Assert(d_gst_cache[effort].find(n) != d_gst_cache[effort].end());
139 sterms.push_back(d_gst_cache[effort][n].d_sterm);
140 exp.push_back(std::vector<Node>());
141 exp[0].insert(exp[0].end(),
142 d_gst_cache[effort][n].d_exp.begin(),
143 d_gst_cache[effort][n].d_exp.end());
144 }
145 }
146 else
147 {
148 Trace("extt-debug") << "getSubstitutedTerms for " << terms.size() << " / "
149 << d_ext_func_terms.size() << " extended functions."
150 << std::endl;
151 if (!terms.empty())
152 {
153 // all variables we need to find a substitution for
154 std::vector<Node> vars;
155 std::vector<Node> sub;
156 std::map<Node, std::vector<Node> > expc;
157 for (const Node& n : terms)
158 {
159 // do substitution, rewrite
160 std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
161 Assert(iti != d_extf_info.end());
162 for (const Node& v : iti->second.d_vars)
163 {
164 if (std::find(vars.begin(), vars.end(), v) == vars.end())
165 {
166 vars.push_back(v);
167 }
168 }
169 }
170 bool useSubs = d_parent.getCurrentSubstitution(effort, vars, sub, expc);
171 // get the current substitution for all variables
172 Assert(!useSubs || vars.size() == sub.size());
173 for (const Node& n : terms)
174 {
175 Node ns = n;
176 std::vector<Node> expn;
177 if (useSubs)
178 {
179 // do substitution
180 ns = n.substitute(vars.begin(), vars.end(), sub.begin(), sub.end());
181 if (ns != n)
182 {
183 // build explanation: explanation vars = sub for each vars in FV(n)
184 std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
185 Assert(iti != d_extf_info.end());
186 for (const Node& v : iti->second.d_vars)
187 {
188 std::map<Node, std::vector<Node> >::iterator itx = expc.find(v);
189 if (itx != expc.end())
190 {
191 for (const Node& e : itx->second)
192 {
193 if (std::find(expn.begin(), expn.end(), e) == expn.end())
194 {
195 expn.push_back(e);
196 }
197 }
198 }
199 }
200 }
201 Trace("extt-debug")
202 << " have " << n << " == " << ns << ", exp size=" << expn.size()
203 << "." << std::endl;
204 }
205 // add to vector
206 sterms.push_back(ns);
207 exp.push_back(expn);
208 // add to cache
209 if (d_cacheEnabled)
210 {
211 d_gst_cache[effort][n].d_sterm = ns;
212 d_gst_cache[effort][n].d_exp.clear();
213 d_gst_cache[effort][n].d_exp.insert(
214 d_gst_cache[effort][n].d_exp.end(), expn.begin(), expn.end());
215 }
216 }
217 }
218 }
219 }
220
221 bool ExtTheory::doInferencesInternal(int effort,
222 const std::vector<Node>& terms,
223 std::vector<Node>& nred,
224 bool batch,
225 bool isRed)
226 {
227 if (batch)
228 {
229 bool addedLemma = false;
230 if (isRed)
231 {
232 for (const Node& n : terms)
233 {
234 Node nr;
235 // note: could do reduction with substitution here
236 bool satDep = false;
237 if (!d_parent.getReduction(effort, n, nr, satDep))
238 {
239 nred.push_back(n);
240 }
241 else
242 {
243 if (!nr.isNull() && n != nr)
244 {
245 Node lem = NodeManager::currentNM()->mkNode(kind::EQUAL, n, nr);
246 if (sendLemma(lem, true))
247 {
248 Trace("extt-lemma")
249 << "ExtTheory : reduction lemma : " << lem << std::endl;
250 addedLemma = true;
251 }
252 }
253 markReduced(n, satDep);
254 }
255 }
256 }
257 else
258 {
259 std::vector<Node> sterms;
260 std::vector<std::vector<Node> > exp;
261 getSubstitutedTerms(effort, terms, sterms, exp);
262 std::map<Node, unsigned> sterm_index;
263 NodeManager* nm = NodeManager::currentNM();
264 for (unsigned i = 0, size = terms.size(); i < size; i++)
265 {
266 bool processed = false;
267 if (sterms[i] != terms[i])
268 {
269 Node sr = Rewriter::rewrite(sterms[i]);
270 // ask the theory if this term is reduced, e.g. is it constant or it
271 // is a non-extf term.
272 if (d_parent.isExtfReduced(effort, sr, terms[i], exp[i]))
273 {
274 processed = true;
275 markReduced(terms[i]);
276 // We have exp[i] => terms[i] = sr, convert this to a clause.
277 // This ensures the proof infrastructure can process this as a
278 // normal theory lemma.
279 Node eq = terms[i].eqNode(sr);
280 Node lem = eq;
281 if (!exp[i].empty())
282 {
283 std::vector<Node> eei;
284 for (const Node& e : exp[i])
285 {
286 eei.push_back(e.negate());
287 }
288 eei.push_back(eq);
289 lem = nm->mkNode(kind::OR, eei);
290 }
291
292 Trace("extt-debug") << "ExtTheory::doInferences : infer : " << eq
293 << " by " << exp[i] << std::endl;
294 Trace("extt-debug") << "...send lemma " << lem << std::endl;
295 if (sendLemma(lem))
296 {
297 Trace("extt-lemma")
298 << "ExtTheory : substitution + rewrite lemma : " << lem
299 << std::endl;
300 addedLemma = true;
301 }
302 }
303 else
304 {
305 // check if we have already reduced this
306 std::map<Node, unsigned>::iterator itsi = sterm_index.find(sr);
307 if (itsi == sterm_index.end())
308 {
309 sterm_index[sr] = i;
310 }
311 else
312 {
313 // unsigned j = itsi->second;
314 // note : can add (non-reducing) lemma :
315 // exp[j] ^ exp[i] => sterms[i] = sterms[j]
316 }
317
318 Trace("extt-nred") << "Non-reduced term : " << sr << std::endl;
319 }
320 }
321 else
322 {
323 Trace("extt-nred") << "Non-reduced term : " << sterms[i] << std::endl;
324 }
325 if (!processed)
326 {
327 nred.push_back(terms[i]);
328 }
329 }
330 }
331 return addedLemma;
332 }
333 // non-batch
334 std::vector<Node> nnred;
335 if (terms.empty())
336 {
337 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
338 it != d_ext_func_terms.end();
339 ++it)
340 {
341 if ((*it).second && !isContextIndependentInactive((*it).first))
342 {
343 std::vector<Node> nterms;
344 nterms.push_back((*it).first);
345 if (doInferencesInternal(effort, nterms, nnred, true, isRed))
346 {
347 return true;
348 }
349 }
350 }
351 }
352 else
353 {
354 for (const Node& n : terms)
355 {
356 std::vector<Node> nterms;
357 nterms.push_back(n);
358 if (doInferencesInternal(effort, nterms, nnred, true, isRed))
359 {
360 return true;
361 }
362 }
363 }
364 return false;
365 }
366
367 bool ExtTheory::sendLemma(Node lem, bool preprocess)
368 {
369 if (preprocess)
370 {
371 if (d_pp_lemmas.find(lem) == d_pp_lemmas.end())
372 {
373 d_pp_lemmas.insert(lem);
374 d_out.lemma(lem, LemmaProperty::PREPROCESS);
375 return true;
376 }
377 }
378 else
379 {
380 if (d_lemmas.find(lem) == d_lemmas.end())
381 {
382 d_lemmas.insert(lem);
383 d_out.lemma(lem);
384 return true;
385 }
386 }
387 return false;
388 }
389
390 bool ExtTheory::doInferences(int effort,
391 const std::vector<Node>& terms,
392 std::vector<Node>& nred,
393 bool batch)
394 {
395 if (!terms.empty())
396 {
397 return doInferencesInternal(effort, terms, nred, batch, false);
398 }
399 return false;
400 }
401
402 bool ExtTheory::doInferences(int effort, std::vector<Node>& nred, bool batch)
403 {
404 std::vector<Node> terms = getActive();
405 return doInferencesInternal(effort, terms, nred, batch, false);
406 }
407
408 bool ExtTheory::doReductions(int effort,
409 const std::vector<Node>& terms,
410 std::vector<Node>& nred,
411 bool batch)
412 {
413 if (!terms.empty())
414 {
415 return doInferencesInternal(effort, terms, nred, batch, true);
416 }
417 return false;
418 }
419
420 bool ExtTheory::doReductions(int effort, std::vector<Node>& nred, bool batch)
421 {
422 const std::vector<Node> terms = getActive();
423 return doInferencesInternal(effort, terms, nred, batch, true);
424 }
425
426 // Register term.
427 void ExtTheory::registerTerm(Node n)
428 {
429 if (d_extf_kind.find(n.getKind()) != d_extf_kind.end())
430 {
431 if (d_ext_func_terms.find(n) == d_ext_func_terms.end())
432 {
433 Trace("extt-debug") << "Found extended function : " << n << std::endl;
434 d_ext_func_terms[n] = true;
435 d_has_extf = n;
436 d_extf_info[n].d_vars = collectVars(n);
437 }
438 }
439 }
440
441 void ExtTheory::registerTermRec(Node n)
442 {
443 std::unordered_set<TNode, TNodeHashFunction> visited;
444 std::vector<TNode> visit;
445 TNode cur;
446 visit.push_back(n);
447 do
448 {
449 cur = visit.back();
450 visit.pop_back();
451 if (visited.find(cur) == visited.end())
452 {
453 visited.insert(cur);
454 registerTerm(cur);
455 for (const Node& cc : cur)
456 {
457 visit.push_back(cc);
458 }
459 }
460 } while (!visit.empty());
461 }
462
463 // mark reduced
464 void ExtTheory::markReduced(Node n, bool satDep)
465 {
466 Trace("extt-debug") << "Mark reduced " << n << std::endl;
467 registerTerm(n);
468 Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end());
469 d_ext_func_terms[n] = false;
470 if (!satDep)
471 {
472 d_ci_inactive.insert(n);
473 }
474
475 // update has_extf
476 if (d_has_extf.get() == n)
477 {
478 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
479 it != d_ext_func_terms.end();
480 ++it)
481 {
482 // if not already reduced
483 if ((*it).second && !isContextIndependentInactive((*it).first))
484 {
485 d_has_extf = (*it).first;
486 }
487 }
488 }
489 }
490
491 // mark congruent
492 void ExtTheory::markCongruent(Node a, Node b)
493 {
494 Trace("extt-debug") << "Mark congruent : " << a << " " << b << std::endl;
495 registerTerm(a);
496 registerTerm(b);
497 NodeBoolMap::const_iterator it = d_ext_func_terms.find(b);
498 if (it != d_ext_func_terms.end())
499 {
500 if (d_ext_func_terms.find(a) != d_ext_func_terms.end())
501 {
502 d_ext_func_terms[a] = d_ext_func_terms[a] && (*it).second;
503 }
504 else
505 {
506 Assert(false);
507 }
508 d_ext_func_terms[b] = false;
509 }
510 else
511 {
512 Assert(false);
513 }
514 }
515
516 bool ExtTheory::isContextIndependentInactive(Node n) const
517 {
518 return d_ci_inactive.find(n) != d_ci_inactive.end();
519 }
520
521 void ExtTheory::getTerms(std::vector<Node>& terms)
522 {
523 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
524 it != d_ext_func_terms.end();
525 ++it)
526 {
527 terms.push_back((*it).first);
528 }
529 }
530
531 bool ExtTheory::hasActiveTerm() const { return !d_has_extf.get().isNull(); }
532
533 // is active
534 bool ExtTheory::isActive(Node n) const
535 {
536 NodeBoolMap::const_iterator it = d_ext_func_terms.find(n);
537 if (it != d_ext_func_terms.end())
538 {
539 return (*it).second && !isContextIndependentInactive(n);
540 }
541 return false;
542 }
543
544 // get active
545 std::vector<Node> ExtTheory::getActive() const
546 {
547 std::vector<Node> active;
548 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
549 it != d_ext_func_terms.end();
550 ++it)
551 {
552 // if not already reduced
553 if ((*it).second && !isContextIndependentInactive((*it).first))
554 {
555 active.push_back((*it).first);
556 }
557 }
558 return active;
559 }
560
561 std::vector<Node> ExtTheory::getActive(Kind k) const
562 {
563 std::vector<Node> active;
564 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
565 it != d_ext_func_terms.end();
566 ++it)
567 {
568 // if not already reduced
569 if ((*it).first.getKind() == k && (*it).second
570 && !isContextIndependentInactive((*it).first))
571 {
572 active.push_back((*it).first);
573 }
574 }
575 return active;
576 }
577
578 void ExtTheory::clearCache() { d_gst_cache.clear(); }
579
580 } /* CVC4::theory namespace */
581 } /* CVC4 namespace */