Improve arithmetic proofs (#6106)
[cvc5.git] / src / expr / term_conversion_proof_generator.cpp
1 /********************* */
2 /*! \file term_conversion_proof_generator.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of term conversion proof generator utility
13 **/
14
15 #include "expr/term_conversion_proof_generator.h"
16
17 #include "expr/proof_checker.h"
18 #include "expr/proof_node.h"
19 #include "expr/term_context.h"
20 #include "expr/term_context_stack.h"
21
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25
26 std::ostream& operator<<(std::ostream& out, TConvPolicy tcpol)
27 {
28 switch (tcpol)
29 {
30 case TConvPolicy::FIXPOINT: out << "FIXPOINT"; break;
31 case TConvPolicy::ONCE: out << "ONCE"; break;
32 default: out << "TConvPolicy:unknown"; break;
33 }
34 return out;
35 }
36
37 std::ostream& operator<<(std::ostream& out, TConvCachePolicy tcpol)
38 {
39 switch (tcpol)
40 {
41 case TConvCachePolicy::STATIC: out << "STATIC"; break;
42 case TConvCachePolicy::DYNAMIC: out << "DYNAMIC"; break;
43 case TConvCachePolicy::NEVER: out << "NEVER"; break;
44 default: out << "TConvCachePolicy:unknown"; break;
45 }
46 return out;
47 }
48
49 TConvProofGenerator::TConvProofGenerator(ProofNodeManager* pnm,
50 context::Context* c,
51 TConvPolicy pol,
52 TConvCachePolicy cpol,
53 std::string name,
54 TermContext* tccb,
55 bool rewriteOps)
56 : d_proof(pnm, nullptr, c, name + "::LazyCDProof"),
57 d_preRewriteMap(c ? c : &d_context),
58 d_postRewriteMap(c ? c : &d_context),
59 d_policy(pol),
60 d_cpolicy(cpol),
61 d_name(name),
62 d_tcontext(tccb),
63 d_rewriteOps(rewriteOps)
64 {
65 }
66
67 TConvProofGenerator::~TConvProofGenerator() {}
68
69 void TConvProofGenerator::addRewriteStep(Node t,
70 Node s,
71 ProofGenerator* pg,
72 bool isPre,
73 PfRule trustId,
74 bool isClosed,
75 uint32_t tctx)
76 {
77 Node eq = registerRewriteStep(t, s, tctx, isPre);
78 if (!eq.isNull())
79 {
80 d_proof.addLazyStep(eq, pg, trustId, isClosed);
81 }
82 }
83
84 void TConvProofGenerator::addRewriteStep(
85 Node t, Node s, ProofStep ps, bool isPre, uint32_t tctx)
86 {
87 Node eq = registerRewriteStep(t, s, tctx, isPre);
88 if (!eq.isNull())
89 {
90 d_proof.addStep(eq, ps);
91 }
92 }
93
94 void TConvProofGenerator::addRewriteStep(Node t,
95 Node s,
96 PfRule id,
97 const std::vector<Node>& children,
98 const std::vector<Node>& args,
99 bool isPre,
100 uint32_t tctx)
101 {
102 Node eq = registerRewriteStep(t, s, tctx, isPre);
103 if (!eq.isNull())
104 {
105 d_proof.addStep(eq, id, children, args);
106 }
107 }
108
109 bool TConvProofGenerator::hasRewriteStep(Node t,
110 uint32_t tctx,
111 bool isPre) const
112 {
113 return !getRewriteStep(t, tctx, isPre).isNull();
114 }
115
116 Node TConvProofGenerator::getRewriteStep(Node t,
117 uint32_t tctx,
118 bool isPre) const
119 {
120 Node thash = t;
121 if (d_tcontext != nullptr)
122 {
123 thash = TCtxNode::computeNodeHash(t, tctx);
124 }
125 return getRewriteStepInternal(thash, isPre);
126 }
127
128 Node TConvProofGenerator::registerRewriteStep(Node t,
129 Node s,
130 uint32_t tctx,
131 bool isPre)
132 {
133 if (t == s)
134 {
135 return Node::null();
136 }
137 Node thash = t;
138 if (d_tcontext != nullptr)
139 {
140 thash = TCtxNode::computeNodeHash(t, tctx);
141 }
142 else
143 {
144 // don't use term context ids if not using term context
145 Assert(tctx == 0);
146 }
147 // should not rewrite term to two different things
148 if (!getRewriteStepInternal(thash, isPre).isNull())
149 {
150 Assert(getRewriteStepInternal(thash, isPre) == s)
151 << identify() << " rewriting " << t << " to both " << s << " and "
152 << getRewriteStepInternal(thash, isPre);
153 return Node::null();
154 }
155 NodeNodeMap& rm = isPre ? d_preRewriteMap : d_postRewriteMap;
156 rm[thash] = s;
157 if (d_cpolicy == TConvCachePolicy::DYNAMIC)
158 {
159 // clear the cache
160 d_cache.clear();
161 }
162 return t.eqNode(s);
163 }
164
165 std::shared_ptr<ProofNode> TConvProofGenerator::getProofFor(Node f)
166 {
167 Trace("tconv-pf-gen") << "TConvProofGenerator::getProofFor: " << identify()
168 << ": " << f << std::endl;
169 if (f.getKind() != EQUAL)
170 {
171 std::stringstream serr;
172 serr << "TConvProofGenerator::getProofFor: " << identify()
173 << ": fail, non-equality " << f;
174 Unhandled() << serr.str();
175 Trace("tconv-pf-gen") << serr.str() << std::endl;
176 return nullptr;
177 }
178 // we use the existing proofs
179 LazyCDProof lpf(
180 d_proof.getManager(), &d_proof, nullptr, d_name + "::LazyCDProof");
181 if (f[0] == f[1])
182 {
183 // assertion failure in debug
184 Assert(false) << "TConvProofGenerator::getProofFor: " << identify()
185 << ": don't ask for trivial proofs";
186 lpf.addStep(f, PfRule::REFL, {}, {f[0]});
187 }
188 else
189 {
190 Node conc = getProofForRewriting(f[0], lpf, d_tcontext);
191 if (conc != f)
192 {
193 bool debugTraceEnabled = Trace.isOn("tconv-pf-gen-debug");
194 Assert(conc.getKind() == EQUAL && conc[0] == f[0]);
195 std::stringstream serr;
196 serr << "TConvProofGenerator::getProofFor: " << toStringDebug()
197 << ": failed, mismatch";
198 if (!debugTraceEnabled)
199 {
200 serr << " (see -t tconv-pf-gen-debug for details)";
201 }
202 serr << std::endl;
203 serr << " source: " << f[0] << std::endl;
204 serr << " requested conclusion: " << f[1] << std::endl;
205 serr << "conclusion from generator: " << conc[1] << std::endl;
206
207 if (debugTraceEnabled)
208 {
209 Trace("tconv-pf-gen-debug") << "Printing rewrite steps..." << std::endl;
210 for (size_t r = 0; r < 2; r++)
211 {
212 const NodeNodeMap& rm = r == 0 ? d_preRewriteMap : d_postRewriteMap;
213 serr << "Rewrite steps (" << (r == 0 ? "pre" : "post")
214 << "):" << std::endl;
215 for (NodeNodeMap::const_iterator it = rm.begin(); it != rm.end();
216 ++it)
217 {
218 serr << (*it).first << " -> " << (*it).second << std::endl;
219 }
220 }
221 }
222 Unhandled() << serr.str();
223 return nullptr;
224 }
225 }
226 std::shared_ptr<ProofNode> pfn = lpf.getProofFor(f);
227 Trace("tconv-pf-gen") << "... success" << std::endl;
228 Assert (pfn!=nullptr);
229 Trace("tconv-pf-gen-debug") << "... proof is " << *pfn << std::endl;
230 return pfn;
231 }
232
233 Node TConvProofGenerator::getProofForRewriting(Node t,
234 LazyCDProof& pf,
235 TermContext* tctx)
236 {
237 NodeManager* nm = NodeManager::currentNM();
238 // Invariant: if visited[hash(t)] = s or rewritten[hash(t)] = s and t,s are
239 // distinct, then pf is able to generate a proof of t=s. We must
240 // Node in the domains of the maps below due to hashing creating new (SEXPR)
241 // nodes.
242
243 // the final rewritten form of terms
244 std::unordered_map<Node, Node, TNodeHashFunction> visited;
245 // the rewritten form of terms we have processed so far
246 std::unordered_map<Node, Node, TNodeHashFunction> rewritten;
247 std::unordered_map<Node, Node, TNodeHashFunction>::iterator it;
248 std::unordered_map<Node, Node, TNodeHashFunction>::iterator itr;
249 std::map<Node, std::shared_ptr<ProofNode> >::iterator itc;
250 Trace("tconv-pf-gen-rewrite")
251 << "TConvProofGenerator::getProofForRewriting: " << toStringDebug()
252 << std::endl;
253 Trace("tconv-pf-gen-rewrite") << "Input: " << t << std::endl;
254 // if provided, we use term context for cache
255 std::shared_ptr<TCtxStack> visitctx;
256 // otherwise, visit is used if we don't have a term context
257 std::vector<TNode> visit;
258 Node tinitialHash;
259 if (tctx != nullptr)
260 {
261 visitctx = std::make_shared<TCtxStack>(tctx);
262 visitctx->pushInitial(t);
263 tinitialHash = TCtxNode::computeNodeHash(t, tctx->initialValue());
264 }
265 else
266 {
267 visit.push_back(t);
268 tinitialHash = t;
269 }
270 Node cur;
271 uint32_t curCVal = 0;
272 Node curHash;
273 do
274 {
275 // pop the top element
276 if (tctx != nullptr)
277 {
278 std::pair<Node, uint32_t> curPair = visitctx->getCurrent();
279 cur = curPair.first;
280 curCVal = curPair.second;
281 curHash = TCtxNode::computeNodeHash(cur, curCVal);
282 visitctx->pop();
283 }
284 else
285 {
286 cur = visit.back();
287 curHash = cur;
288 visit.pop_back();
289 }
290 Trace("tconv-pf-gen-rewrite") << "* visit : " << curHash << std::endl;
291 // has the proof for cur been cached?
292 itc = d_cache.find(curHash);
293 if (itc != d_cache.end())
294 {
295 Node res = itc->second->getResult();
296 Assert(res.getKind() == EQUAL);
297 Assert(!res[1].isNull());
298 visited[curHash] = res[1];
299 pf.addProof(itc->second);
300 continue;
301 }
302 it = visited.find(curHash);
303 if (it == visited.end())
304 {
305 Trace("tconv-pf-gen-rewrite") << "- previsit" << std::endl;
306 visited[curHash] = Node::null();
307 // did we rewrite the current node (at pre-rewrite)?
308 Node rcur = getRewriteStepInternal(curHash, true);
309 if (!rcur.isNull())
310 {
311 Trace("tconv-pf-gen-rewrite")
312 << "*** " << curHash << " prerewrites to " << rcur << std::endl;
313 // d_proof has a proof of cur = rcur. Hence there is nothing
314 // to do here, as pf will reference d_proof to get its proof.
315 if (d_policy == TConvPolicy::FIXPOINT)
316 {
317 // It may be the case that rcur also rewrites, thus we cannot assign
318 // the final rewritten form for cur yet. Instead we revisit cur after
319 // finishing visiting rcur.
320 rewritten[curHash] = rcur;
321 if (tctx != nullptr)
322 {
323 visitctx->push(cur, curCVal);
324 visitctx->push(rcur, curCVal);
325 }
326 else
327 {
328 visit.push_back(cur);
329 visit.push_back(rcur);
330 }
331 }
332 else
333 {
334 Assert(d_policy == TConvPolicy::ONCE);
335 Trace("tconv-pf-gen-rewrite") << "-> (once, prewrite) " << curHash
336 << " = " << rcur << std::endl;
337 // not rewriting again, rcur is final
338 Assert(!rcur.isNull());
339 visited[curHash] = rcur;
340 doCache(curHash, cur, rcur, pf);
341 }
342 }
343 else if (tctx != nullptr)
344 {
345 visitctx->push(cur, curCVal);
346 // visit operator if apply uf
347 if (d_rewriteOps && cur.getKind() == APPLY_UF)
348 {
349 visitctx->pushOp(cur, curCVal);
350 }
351 visitctx->pushChildren(cur, curCVal);
352 }
353 else
354 {
355 visit.push_back(cur);
356 // visit operator if apply uf
357 if (d_rewriteOps && cur.getKind() == APPLY_UF)
358 {
359 visit.push_back(cur.getOperator());
360 }
361 visit.insert(visit.end(), cur.begin(), cur.end());
362 }
363 }
364 else if (it->second.isNull())
365 {
366 itr = rewritten.find(curHash);
367 if (itr != rewritten.end())
368 {
369 // only can generate partially rewritten nodes when rewrite again is
370 // true.
371 Assert(d_policy != TConvPolicy::ONCE);
372 // if it was rewritten, check the status of the rewritten node,
373 // which should be finished now
374 Node rcur = itr->second;
375 Trace("tconv-pf-gen-rewrite")
376 << "- postvisit, previously rewritten to " << rcur << std::endl;
377 Node rcurHash = rcur;
378 if (tctx != nullptr)
379 {
380 rcurHash = TCtxNode::computeNodeHash(rcur, curCVal);
381 }
382 Assert(cur != rcur);
383 // the final rewritten form of cur is the final form of rcur
384 Node rcurFinal = visited[rcurHash];
385 Assert(!rcurFinal.isNull());
386 if (rcurFinal != rcur)
387 {
388 // must connect via TRANS
389 std::vector<Node> pfChildren;
390 pfChildren.push_back(cur.eqNode(rcur));
391 pfChildren.push_back(rcur.eqNode(rcurFinal));
392 Node result = cur.eqNode(rcurFinal);
393 pf.addStep(result, PfRule::TRANS, pfChildren, {});
394 }
395 Trace("tconv-pf-gen-rewrite")
396 << "-> (rewritten postrewrite) " << curHash << " = " << rcurFinal
397 << std::endl;
398 visited[curHash] = rcurFinal;
399 doCache(curHash, cur, rcurFinal, pf);
400 }
401 else
402 {
403 Trace("tconv-pf-gen-rewrite") << "- postvisit" << std::endl;
404 Node ret = cur;
405 Node retHash = curHash;
406 bool childChanged = false;
407 std::vector<Node> children;
408 Kind ck = cur.getKind();
409 if (d_rewriteOps && ck == APPLY_UF)
410 {
411 // the operator of APPLY_UF is visited
412 Node cop = cur.getOperator();
413 if (tctx != nullptr)
414 {
415 uint32_t coval = tctx->computeValueOp(cur, curCVal);
416 Node coHash = TCtxNode::computeNodeHash(cop, coval);
417 it = visited.find(coHash);
418 }
419 else
420 {
421 it = visited.find(cop);
422 }
423 Assert(it != visited.end());
424 Assert(!it->second.isNull());
425 childChanged = childChanged || cop != it->second;
426 children.push_back(it->second);
427 }
428 else if (cur.getMetaKind() == metakind::PARAMETERIZED)
429 {
430 // all other parametrized operators are unchanged
431 children.push_back(cur.getOperator());
432 }
433 // get the results of the children
434 if (tctx != nullptr)
435 {
436 for (size_t i = 0, nchild = cur.getNumChildren(); i < nchild; i++)
437 {
438 Node cn = cur[i];
439 uint32_t cnval = tctx->computeValue(cur, curCVal, i);
440 Node cnHash = TCtxNode::computeNodeHash(cn, cnval);
441 it = visited.find(cnHash);
442 Assert(it != visited.end());
443 Assert(!it->second.isNull());
444 childChanged = childChanged || cn != it->second;
445 children.push_back(it->second);
446 }
447 }
448 else
449 {
450 // can use simple loop if not term-context-sensitive
451 for (const Node& cn : cur)
452 {
453 it = visited.find(cn);
454 Assert(it != visited.end());
455 Assert(!it->second.isNull());
456 childChanged = childChanged || cn != it->second;
457 children.push_back(it->second);
458 }
459 }
460 if (childChanged)
461 {
462 ret = nm->mkNode(ck, children);
463 rewritten[curHash] = ret;
464 // congruence to show (cur = ret)
465 PfRule congRule = PfRule::CONG;
466 std::vector<Node> pfChildren;
467 std::vector<Node> pfArgs;
468 pfArgs.push_back(ProofRuleChecker::mkKindNode(ck));
469 if (ck == APPLY_UF && children[0] != cur.getOperator())
470 {
471 // use HO_CONG if the operator changed
472 congRule = PfRule::HO_CONG;
473 pfChildren.push_back(cur.getOperator().eqNode(children[0]));
474 }
475 else if (kind::metaKindOf(ck) == kind::metakind::PARAMETERIZED)
476 {
477 pfArgs.push_back(cur.getOperator());
478 }
479 for (size_t i = 0, size = cur.getNumChildren(); i < size; i++)
480 {
481 if (cur[i] == ret[i])
482 {
483 // ensure REFL proof for unchanged children
484 pf.addStep(cur[i].eqNode(cur[i]), PfRule::REFL, {}, {cur[i]});
485 }
486 pfChildren.push_back(cur[i].eqNode(ret[i]));
487 }
488 Node result = cur.eqNode(ret);
489 pf.addStep(result, congRule, pfChildren, pfArgs);
490 // must update the hash
491 retHash = ret;
492 if (tctx != nullptr)
493 {
494 retHash = TCtxNode::computeNodeHash(ret, curCVal);
495 }
496 }
497 else if (tctx != nullptr)
498 {
499 // now we need the hash
500 retHash = TCtxNode::computeNodeHash(cur, curCVal);
501 }
502 // did we rewrite ret (at post-rewrite)?
503 Node rret = getRewriteStepInternal(retHash, false);
504 if (!rret.isNull() && d_policy == TConvPolicy::FIXPOINT)
505 {
506 Trace("tconv-pf-gen-rewrite")
507 << "*** " << retHash << " postrewrites to " << rret << std::endl;
508 // d_proof should have a proof of ret = rret, hence nothing to do
509 // here, for the same reasons as above. It also may be the case that
510 // rret rewrites, hence we must revisit ret.
511 rewritten[retHash] = rret;
512 if (tctx != nullptr)
513 {
514 if (cur != ret)
515 {
516 visitctx->push(cur, curCVal);
517 }
518 visitctx->push(ret, curCVal);
519 visitctx->push(rret, curCVal);
520 }
521 else
522 {
523 if (cur != ret)
524 {
525 visit.push_back(cur);
526 }
527 visit.push_back(ret);
528 visit.push_back(rret);
529 }
530 }
531 else
532 {
533 // take its rewrite if it rewrote and we have ONCE rewriting policy
534 ret = rret.isNull() ? ret : rret;
535 Trace("tconv-pf-gen-rewrite")
536 << "-> (postrewrite) " << curHash << " = " << ret << std::endl;
537 // it is final
538 Assert(!ret.isNull());
539 visited[curHash] = ret;
540 doCache(curHash, cur, ret, pf);
541 }
542 }
543 }
544 else
545 {
546 Trace("tconv-pf-gen-rewrite") << "- already visited" << std::endl;
547 }
548 } while (!(tctx != nullptr ? visitctx->empty() : visit.empty()));
549 Assert(visited.find(tinitialHash) != visited.end());
550 Assert(!visited.find(tinitialHash)->second.isNull());
551 Trace("tconv-pf-gen-rewrite")
552 << "...finished, return " << visited[tinitialHash] << std::endl;
553 // return the conclusion of the overall proof
554 return t.eqNode(visited[tinitialHash]);
555 }
556
557 void TConvProofGenerator::doCache(Node curHash,
558 Node cur,
559 Node r,
560 LazyCDProof& pf)
561 {
562 if (d_cpolicy != TConvCachePolicy::NEVER)
563 {
564 Node eq = cur.eqNode(r);
565 d_cache[curHash] = pf.getProofFor(eq);
566 }
567 }
568
569 Node TConvProofGenerator::getRewriteStepInternal(Node t, bool isPre) const
570 {
571 const NodeNodeMap& rm = isPre ? d_preRewriteMap : d_postRewriteMap;
572 NodeNodeMap::const_iterator it = rm.find(t);
573 if (it == rm.end())
574 {
575 return Node::null();
576 }
577 return (*it).second;
578 }
579 std::string TConvProofGenerator::identify() const { return d_name; }
580
581 std::string TConvProofGenerator::toStringDebug() const
582 {
583 std::stringstream ss;
584 ss << identify() << " (policy=" << d_policy << ", cache policy=" << d_cpolicy
585 << (d_tcontext != nullptr ? ", term-context-sensitive" : "") << ")";
586 return ss.str();
587 }
588
589 } // namespace CVC4