1 /********************* */
2 /*! \file ext_theory.cpp
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Tim King, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
12 ** \brief Extended theory interface.
14 ** This implements a generic module, used by theory solvers, for performing
15 ** "context-dependent simplification", as described in Reynolds et al
16 ** "Designing Theory Solvers with Extensions", FroCoS 2017.
19 #include "theory/ext_theory.h"
21 #include "base/check.h"
22 #include "smt/smt_statistics_registry.h"
23 #include "theory/quantifiers_engine.h"
24 #include "theory/substitutions.h"
31 bool ExtTheoryCallback::getCurrentSubstitution(
33 const std::vector
<Node
>& vars
,
34 std::vector
<Node
>& subs
,
35 std::map
<Node
, std::vector
<Node
> >& exp
)
39 bool ExtTheoryCallback::isExtfReduced(int effort
,
42 std::vector
<Node
>& exp
)
46 bool ExtTheoryCallback::getReduction(int effort
,
54 ExtTheory::ExtTheory(ExtTheoryCallback
& p
,
56 context::UserContext
* u
,
66 d_cacheEnabled(cacheEnabled
)
68 d_true
= NodeManager::currentNM()->mkConst(true);
71 // Gets all leaf terms in n.
72 std::vector
<Node
> ExtTheory::collectVars(Node n
)
74 std::vector
<Node
> vars
;
75 std::set
<Node
> visited
;
76 std::vector
<Node
> worklist
;
77 worklist
.push_back(n
);
78 while (!worklist
.empty())
80 Node current
= worklist
.back();
82 if (current
.isConst() || visited
.count(current
) > 0)
86 visited
.insert(current
);
87 // Treat terms not belonging to this theory as leaf
// note : could include terms not belonging to this theory
90 if (current
.getNumChildren() > 0)
92 worklist
.insert(worklist
.end(), current
.begin(), current
.end());
96 vars
.push_back(current
);
102 Node
ExtTheory::getSubstitutedTerm(int effort
,
104 std::vector
<Node
>& exp
,
109 Assert(d_gst_cache
[effort
].find(term
) != d_gst_cache
[effort
].end());
110 exp
.insert(exp
.end(),
111 d_gst_cache
[effort
][term
].d_exp
.begin(),
112 d_gst_cache
[effort
][term
].d_exp
.end());
113 return d_gst_cache
[effort
][term
].d_sterm
;
116 std::vector
<Node
> terms
;
117 terms
.push_back(term
);
118 std::vector
<Node
> sterms
;
119 std::vector
<std::vector
<Node
> > exps
;
120 getSubstitutedTerms(effort
, terms
, sterms
, exps
, useCache
);
121 Assert(sterms
.size() == 1);
122 Assert(exps
.size() == 1);
123 exp
.insert(exp
.end(), exps
[0].begin(), exps
[0].end());
128 void ExtTheory::getSubstitutedTerms(int effort
,
129 const std::vector
<Node
>& terms
,
130 std::vector
<Node
>& sterms
,
131 std::vector
<std::vector
<Node
> >& exp
,
136 for (const Node
& n
: terms
)
138 Assert(d_gst_cache
[effort
].find(n
) != d_gst_cache
[effort
].end());
139 sterms
.push_back(d_gst_cache
[effort
][n
].d_sterm
);
140 exp
.push_back(std::vector
<Node
>());
141 exp
[0].insert(exp
[0].end(),
142 d_gst_cache
[effort
][n
].d_exp
.begin(),
143 d_gst_cache
[effort
][n
].d_exp
.end());
148 Trace("extt-debug") << "getSubstitutedTerms for " << terms
.size() << " / "
149 << d_ext_func_terms
.size() << " extended functions."
153 // all variables we need to find a substitution for
154 std::vector
<Node
> vars
;
155 std::vector
<Node
> sub
;
156 std::map
<Node
, std::vector
<Node
> > expc
;
157 for (const Node
& n
: terms
)
159 // do substitution, rewrite
160 std::map
<Node
, ExtfInfo
>::iterator iti
= d_extf_info
.find(n
);
161 Assert(iti
!= d_extf_info
.end());
162 for (const Node
& v
: iti
->second
.d_vars
)
164 if (std::find(vars
.begin(), vars
.end(), v
) == vars
.end())
170 bool useSubs
= d_parent
.getCurrentSubstitution(effort
, vars
, sub
, expc
);
171 // get the current substitution for all variables
172 Assert(!useSubs
|| vars
.size() == sub
.size());
173 for (const Node
& n
: terms
)
176 std::vector
<Node
> expn
;
180 ns
= n
.substitute(vars
.begin(), vars
.end(), sub
.begin(), sub
.end());
183 // build explanation: explanation vars = sub for each vars in FV(n)
184 std::map
<Node
, ExtfInfo
>::iterator iti
= d_extf_info
.find(n
);
185 Assert(iti
!= d_extf_info
.end());
186 for (const Node
& v
: iti
->second
.d_vars
)
188 std::map
<Node
, std::vector
<Node
> >::iterator itx
= expc
.find(v
);
189 if (itx
!= expc
.end())
191 for (const Node
& e
: itx
->second
)
193 if (std::find(expn
.begin(), expn
.end(), e
) == expn
.end())
202 << " have " << n
<< " == " << ns
<< ", exp size=" << expn
.size()
206 sterms
.push_back(ns
);
211 d_gst_cache
[effort
][n
].d_sterm
= ns
;
212 d_gst_cache
[effort
][n
].d_exp
.clear();
213 d_gst_cache
[effort
][n
].d_exp
.insert(
214 d_gst_cache
[effort
][n
].d_exp
.end(), expn
.begin(), expn
.end());
221 bool ExtTheory::doInferencesInternal(int effort
,
222 const std::vector
<Node
>& terms
,
223 std::vector
<Node
>& nred
,
229 bool addedLemma
= false;
232 for (const Node
& n
: terms
)
235 // note: could do reduction with substitution here
237 if (!d_parent
.getReduction(effort
, n
, nr
, satDep
))
243 if (!nr
.isNull() && n
!= nr
)
245 Node lem
= NodeManager::currentNM()->mkNode(kind::EQUAL
, n
, nr
);
246 if (sendLemma(lem
, true))
249 << "ExtTheory : reduction lemma : " << lem
<< std::endl
;
253 markReduced(n
, satDep
);
259 std::vector
<Node
> sterms
;
260 std::vector
<std::vector
<Node
> > exp
;
261 getSubstitutedTerms(effort
, terms
, sterms
, exp
);
262 std::map
<Node
, unsigned> sterm_index
;
263 NodeManager
* nm
= NodeManager::currentNM();
264 for (unsigned i
= 0, size
= terms
.size(); i
< size
; i
++)
266 bool processed
= false;
267 if (sterms
[i
] != terms
[i
])
269 Node sr
= Rewriter::rewrite(sterms
[i
]);
270 // ask the theory if this term is reduced, e.g. is it constant or it
271 // is a non-extf term.
272 if (d_parent
.isExtfReduced(effort
, sr
, terms
[i
], exp
[i
]))
275 markReduced(terms
[i
]);
276 // We have exp[i] => terms[i] = sr, convert this to a clause.
277 // This ensures the proof infrastructure can process this as a
278 // normal theory lemma.
279 Node eq
= terms
[i
].eqNode(sr
);
283 std::vector
<Node
> eei
;
284 for (const Node
& e
: exp
[i
])
286 eei
.push_back(e
.negate());
289 lem
= nm
->mkNode(kind::OR
, eei
);
292 Trace("extt-debug") << "ExtTheory::doInferences : infer : " << eq
293 << " by " << exp
[i
] << std::endl
;
294 Trace("extt-debug") << "...send lemma " << lem
<< std::endl
;
298 << "ExtTheory : substitution + rewrite lemma : " << lem
305 // check if we have already reduced this
306 std::map
<Node
, unsigned>::iterator itsi
= sterm_index
.find(sr
);
307 if (itsi
== sterm_index
.end())
313 // unsigned j = itsi->second;
314 // note : can add (non-reducing) lemma :
315 // exp[j] ^ exp[i] => sterms[i] = sterms[j]
318 Trace("extt-nred") << "Non-reduced term : " << sr
<< std::endl
;
323 Trace("extt-nred") << "Non-reduced term : " << sterms
[i
] << std::endl
;
327 nred
.push_back(terms
[i
]);
334 std::vector
<Node
> nnred
;
337 for (NodeBoolMap::iterator it
= d_ext_func_terms
.begin();
338 it
!= d_ext_func_terms
.end();
341 if ((*it
).second
&& !isContextIndependentInactive((*it
).first
))
343 std::vector
<Node
> nterms
;
344 nterms
.push_back((*it
).first
);
345 if (doInferencesInternal(effort
, nterms
, nnred
, true, isRed
))
354 for (const Node
& n
: terms
)
356 std::vector
<Node
> nterms
;
358 if (doInferencesInternal(effort
, nterms
, nnred
, true, isRed
))
367 bool ExtTheory::sendLemma(Node lem
, bool preprocess
)
371 if (d_pp_lemmas
.find(lem
) == d_pp_lemmas
.end())
373 d_pp_lemmas
.insert(lem
);
374 d_out
.lemma(lem
, LemmaProperty::PREPROCESS
);
380 if (d_lemmas
.find(lem
) == d_lemmas
.end())
382 d_lemmas
.insert(lem
);
390 bool ExtTheory::doInferences(int effort
,
391 const std::vector
<Node
>& terms
,
392 std::vector
<Node
>& nred
,
397 return doInferencesInternal(effort
, terms
, nred
, batch
, false);
402 bool ExtTheory::doInferences(int effort
, std::vector
<Node
>& nred
, bool batch
)
404 std::vector
<Node
> terms
= getActive();
405 return doInferencesInternal(effort
, terms
, nred
, batch
, false);
408 bool ExtTheory::doReductions(int effort
,
409 const std::vector
<Node
>& terms
,
410 std::vector
<Node
>& nred
,
415 return doInferencesInternal(effort
, terms
, nred
, batch
, true);
420 bool ExtTheory::doReductions(int effort
, std::vector
<Node
>& nred
, bool batch
)
422 const std::vector
<Node
> terms
= getActive();
423 return doInferencesInternal(effort
, terms
, nred
, batch
, true);
427 void ExtTheory::registerTerm(Node n
)
429 if (d_extf_kind
.find(n
.getKind()) != d_extf_kind
.end())
431 if (d_ext_func_terms
.find(n
) == d_ext_func_terms
.end())
433 Trace("extt-debug") << "Found extended function : " << n
<< std::endl
;
434 d_ext_func_terms
[n
] = true;
436 d_extf_info
[n
].d_vars
= collectVars(n
);
441 void ExtTheory::registerTermRec(Node n
)
443 std::unordered_set
<TNode
, TNodeHashFunction
> visited
;
444 std::vector
<TNode
> visit
;
451 if (visited
.find(cur
) == visited
.end())
455 for (const Node
& cc
: cur
)
460 } while (!visit
.empty());
464 void ExtTheory::markReduced(Node n
, bool satDep
)
466 Trace("extt-debug") << "Mark reduced " << n
<< std::endl
;
468 Assert(d_ext_func_terms
.find(n
) != d_ext_func_terms
.end());
469 d_ext_func_terms
[n
] = false;
472 d_ci_inactive
.insert(n
);
476 if (d_has_extf
.get() == n
)
478 for (NodeBoolMap::iterator it
= d_ext_func_terms
.begin();
479 it
!= d_ext_func_terms
.end();
482 // if not already reduced
483 if ((*it
).second
&& !isContextIndependentInactive((*it
).first
))
485 d_has_extf
= (*it
).first
;
492 void ExtTheory::markCongruent(Node a
, Node b
)
494 Trace("extt-debug") << "Mark congruent : " << a
<< " " << b
<< std::endl
;
497 NodeBoolMap::const_iterator it
= d_ext_func_terms
.find(b
);
498 if (it
!= d_ext_func_terms
.end())
500 if (d_ext_func_terms
.find(a
) != d_ext_func_terms
.end())
502 d_ext_func_terms
[a
] = d_ext_func_terms
[a
] && (*it
).second
;
508 d_ext_func_terms
[b
] = false;
516 bool ExtTheory::isContextIndependentInactive(Node n
) const
518 return d_ci_inactive
.find(n
) != d_ci_inactive
.end();
521 void ExtTheory::getTerms(std::vector
<Node
>& terms
)
523 for (NodeBoolMap::iterator it
= d_ext_func_terms
.begin();
524 it
!= d_ext_func_terms
.end();
527 terms
.push_back((*it
).first
);
531 bool ExtTheory::hasActiveTerm() const { return !d_has_extf
.get().isNull(); }
534 bool ExtTheory::isActive(Node n
) const
536 NodeBoolMap::const_iterator it
= d_ext_func_terms
.find(n
);
537 if (it
!= d_ext_func_terms
.end())
539 return (*it
).second
&& !isContextIndependentInactive(n
);
545 std::vector
<Node
> ExtTheory::getActive() const
547 std::vector
<Node
> active
;
548 for (NodeBoolMap::iterator it
= d_ext_func_terms
.begin();
549 it
!= d_ext_func_terms
.end();
552 // if not already reduced
553 if ((*it
).second
&& !isContextIndependentInactive((*it
).first
))
555 active
.push_back((*it
).first
);
561 std::vector
<Node
> ExtTheory::getActive(Kind k
) const
563 std::vector
<Node
> active
;
564 for (NodeBoolMap::iterator it
= d_ext_func_terms
.begin();
565 it
!= d_ext_func_terms
.end();
568 // if not already reduced
569 if ((*it
).first
.getKind() == k
&& (*it
).second
570 && !isContextIndependentInactive((*it
).first
))
572 active
.push_back((*it
).first
);
578 void ExtTheory::clearCache() { d_gst_cache
.clear(); }
580 } /* CVC4::theory namespace */
581 } /* CVC4 namespace */