(proof-new) SMT proof postprocess callback (#4883)
[cvc5.git] / src / smt / proof_post_processor.cpp
1 /********************* */
2 /*! \file proof_post_processor.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of module for processing proof nodes
13 **/
14
15 #include "smt/proof_post_processor.h"
16
17 #include "expr/skolem_manager.h"
18 #include "options/smt_options.h"
19 #include "preprocessing/assertion_pipeline.h"
20 #include "smt/smt_engine.h"
21 #include "smt/smt_statistics_registry.h"
22 #include "theory/builtin/proof_checker.h"
23 #include "theory/rewriter.h"
24
25 using namespace CVC4::kind;
26 using namespace CVC4::theory;
27
28 namespace CVC4 {
29 namespace smt {
30
31 ProofPostprocessCallback::ProofPostprocessCallback(ProofNodeManager* pnm,
32 SmtEngine* smte,
33 ProofGenerator* pppg)
34 : d_pnm(pnm), d_smte(smte), d_pppg(pppg), d_wfpm(pnm)
35 {
36 d_true = NodeManager::currentNM()->mkConst(true);
37 // always check whether to update ASSUME
38 d_elimRules.insert(PfRule::ASSUME);
39 }
40
41 void ProofPostprocessCallback::initializeUpdate()
42 {
43 d_assumpToProof.clear();
44 d_wfAssumptions.clear();
45 }
46
47 void ProofPostprocessCallback::setEliminateRule(PfRule rule)
48 {
49 d_elimRules.insert(rule);
50 }
51
52 bool ProofPostprocessCallback::shouldUpdate(ProofNode* pn)
53 {
54 return d_elimRules.find(pn->getRule()) != d_elimRules.end();
55 }
56
57 bool ProofPostprocessCallback::update(Node res,
58 PfRule id,
59 const std::vector<Node>& children,
60 const std::vector<Node>& args,
61 CDProof* cdp)
62 {
63 Trace("smt-proof-pp-debug") << "- Post process " << id << " " << children
64 << " / " << args << std::endl;
65
66 if (id == PfRule::ASSUME)
67 {
68 // we cache based on the assumption node, not the proof node, since there
69 // may be multiple occurrences of the same node.
70 Node f = args[0];
71 std::shared_ptr<ProofNode> pfn;
72 std::map<Node, std::shared_ptr<ProofNode>>::iterator it =
73 d_assumpToProof.find(f);
74 if (it != d_assumpToProof.end())
75 {
76 Trace("smt-proof-pp-debug") << "...already computed" << std::endl;
77 pfn = it->second;
78 }
79 else
80 {
81 Assert(d_pppg != nullptr);
82 // get proof from preprocess proof generator
83 pfn = d_pppg->getProofFor(f);
84 // print for debugging
85 if (pfn == nullptr)
86 {
87 Trace("smt-proof-pp-debug")
88 << "...no proof, possibly an input assumption" << std::endl;
89 }
90 else
91 {
92 Assert(pfn->getResult() == f);
93 if (Trace.isOn("smt-proof-pp"))
94 {
95 Trace("smt-proof-pp")
96 << "=== Connect proof for preprocessing: " << f << std::endl;
97 Trace("smt-proof-pp") << *pfn.get() << std::endl;
98 }
99 }
100 d_assumpToProof[f] = pfn;
101 }
102 if (pfn == nullptr)
103 {
104 // no update
105 return false;
106 }
107 // connect the proof
108 cdp->addProof(pfn);
109 return true;
110 }
111 Node ret = expandMacros(id, children, args, cdp);
112 Trace("smt-proof-pp-debug") << "...expanded = " << !ret.isNull() << std::endl;
113 return !ret.isNull();
114 }
115
116 Node ProofPostprocessCallback::expandMacros(PfRule id,
117 const std::vector<Node>& children,
118 const std::vector<Node>& args,
119 CDProof* cdp)
120 {
121 if (d_elimRules.find(id) == d_elimRules.end())
122 {
123 // not eliminated
124 return Node::null();
125 }
126 // macro elimination
127 if (id == PfRule::MACRO_SR_EQ_INTRO)
128 {
129 // (TRANS
130 // (SUBS <children> :args args[0:1])
131 // (REWRITE :args <t.substitute(x1,t1). ... .substitute(xn,tn)> args[2]))
132 std::vector<Node> tchildren;
133 Node t = args[0];
134 Node ts;
135 if (!children.empty())
136 {
137 std::vector<Node> sargs;
138 sargs.push_back(t);
139 MethodId sid = MethodId::SB_DEFAULT;
140 if (args.size() >= 2)
141 {
142 if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], sid))
143 {
144 sargs.push_back(args[1]);
145 }
146 }
147 ts =
148 builtin::BuiltinProofRuleChecker::applySubstitution(t, children, sid);
149 if (ts != t)
150 {
151 Node eq = t.eqNode(ts);
152 // apply SUBS proof rule if necessary
153 if (!update(eq, PfRule::SUBS, children, sargs, cdp))
154 {
155 // if not elimianted, add as step
156 cdp->addStep(eq, PfRule::SUBS, children, sargs);
157 }
158 tchildren.push_back(eq);
159 }
160 }
161 else
162 {
163 // no substitute
164 ts = t;
165 }
166 std::vector<Node> rargs;
167 rargs.push_back(ts);
168 MethodId rid = MethodId::RW_REWRITE;
169 if (args.size() >= 3)
170 {
171 if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], rid))
172 {
173 rargs.push_back(args[2]);
174 }
175 }
176 builtin::BuiltinProofRuleChecker* builtinPfC =
177 static_cast<builtin::BuiltinProofRuleChecker*>(
178 d_pnm->getChecker()->getCheckerFor(PfRule::MACRO_SR_EQ_INTRO));
179 Node tr = builtinPfC->applyRewrite(ts, rid);
180 if (ts != tr)
181 {
182 Node eq = ts.eqNode(tr);
183 // apply REWRITE proof rule
184 if (!update(eq, PfRule::REWRITE, {}, rargs, cdp))
185 {
186 // if not elimianted, add as step
187 cdp->addStep(eq, PfRule::REWRITE, {}, rargs);
188 }
189 tchildren.push_back(eq);
190 }
191 if (t == tr)
192 {
193 // typically not necessary, but done to be robust
194 cdp->addStep(t.eqNode(tr), PfRule::REFL, {}, {t});
195 return t.eqNode(tr);
196 }
197 // must add TRANS if two step
198 return addProofForTrans(tchildren, cdp);
199 }
200 else if (id == PfRule::MACRO_SR_PRED_INTRO)
201 {
202 std::vector<Node> tchildren;
203 std::vector<Node> sargs = args;
204 // take into account witness form, if necessary
205 if (d_wfpm.requiresWitnessFormIntro(args[0]))
206 {
207 Node weq = addProofForWitnessForm(args[0], cdp);
208 tchildren.push_back(weq);
209 // replace the first argument
210 sargs[0] = weq[1];
211 }
212 // (TRUE_ELIM
213 // (TRANS
214 // ... proof of t = toWitness(t) ...
215 // (MACRO_SR_EQ_INTRO <children> :args (toWitness(t) args[1:]))))
216 // We call the expandMacros method on MACRO_SR_EQ_INTRO, where notice
217 // that this rule application is immediately expanded in the recursive
218 // call and not added to the proof.
219 Node conc = expandMacros(PfRule::MACRO_SR_EQ_INTRO, children, sargs, cdp);
220 tchildren.push_back(conc);
221 Assert(!conc.isNull() && conc.getKind() == EQUAL && conc[0] == sargs[0]
222 && conc[1] == d_true);
223 // transitivity if necessary
224 Node eq = addProofForTrans(tchildren, cdp);
225
226 cdp->addStep(eq[0], PfRule::TRUE_ELIM, {eq}, {});
227 Assert(eq[0] == args[0]);
228 return eq[0];
229 }
230 else if (id == PfRule::MACRO_SR_PRED_ELIM)
231 {
232 // (TRUE_ELIM
233 // (TRANS
234 // (SYMM (MACRO_SR_EQ_INTRO children[1:] :args children[0] ++ args))
235 // (TRUE_INTRO children[0])))
236 std::vector<Node> schildren(children.begin() + 1, children.end());
237 std::vector<Node> srargs;
238 srargs.push_back(children[0]);
239 srargs.insert(srargs.end(), args.begin(), args.end());
240 Node conc = expandMacros(PfRule::MACRO_SR_EQ_INTRO, schildren, srargs, cdp);
241 Assert(!conc.isNull() && conc.getKind() == EQUAL && conc[0] == children[0]);
242
243 Node eq1 = children[0].eqNode(d_true);
244 cdp->addStep(eq1, PfRule::TRUE_INTRO, {children[0]}, {});
245
246 Node concSym = conc[1].eqNode(conc[0]);
247 Node eq2 = conc[1].eqNode(d_true);
248 cdp->addStep(eq2, PfRule::TRANS, {concSym, eq1}, {});
249
250 cdp->addStep(conc[1], PfRule::TRUE_ELIM, {eq2}, {});
251 return conc[1];
252 }
253 else if (id == PfRule::MACRO_SR_PRED_TRANSFORM)
254 {
255 // (TRUE_ELIM
256 // (TRANS
257 // (MACRO_SR_EQ_INTRO children[1:] :args <args>)
258 // ... proof of a = wa
259 // (MACRO_SR_EQ_INTRO {} wa)
260 // (SYMM
261 // (MACRO_SR_EQ_INTRO children[1:] :args (children[0] args[1:]))
262 // ... proof of c = wc
263 // (MACRO_SR_EQ_INTRO {} wc))
264 // (TRUE_INTRO children[0])))
265 // where
266 // wa = toWitness(apply_SR(args[0])) and
267 // wc = toWitness(apply_SR(children[0])).
268 Trace("smt-proof-pp-debug")
269 << "Transform " << children[0] << " == " << args[0] << std::endl;
270 if (CDProof::isSame(children[0], args[0]))
271 {
272 // nothing to do
273 return children[0];
274 }
275 std::vector<Node> tchildren;
276 std::vector<Node> schildren(children.begin() + 1, children.end());
277 std::vector<Node> sargs = args;
278 // first, compute if we need
279 bool reqWitness = d_wfpm.requiresWitnessFormTransform(children[0], args[0]);
280 // convert both sides, in three steps, take symmetry of second chain
281 for (unsigned r = 0; r < 2; r++)
282 {
283 std::vector<Node> tchildrenr;
284 // first rewrite args[0], then children[0]
285 sargs[0] = r == 0 ? args[0] : children[0];
286 // t = apply_SR(t)
287 Node eq = expandMacros(PfRule::MACRO_SR_EQ_INTRO, schildren, sargs, cdp);
288 Trace("smt-proof-pp-debug")
289 << "transform subs_rewrite (" << r << "): " << eq << std::endl;
290 Assert(!eq.isNull() && eq.getKind() == EQUAL && eq[0] == sargs[0]);
291 addToTransChildren(eq, tchildrenr);
292 // apply_SR(t) = toWitness(apply_SR(t))
293 if (reqWitness)
294 {
295 Node weq = addProofForWitnessForm(eq[1], cdp);
296 Trace("smt-proof-pp-debug")
297 << "transform toWitness (" << r << "): " << weq << std::endl;
298 if (addToTransChildren(weq, tchildrenr))
299 {
300 sargs[0] = weq[1];
301 // toWitness(apply_SR(t)) = apply_SR(toWitness(apply_SR(t)))
302 // rewrite again, don't need substitution
303 Node weqr = expandMacros(PfRule::MACRO_SR_EQ_INTRO, {}, sargs, cdp);
304 Trace("smt-proof-pp-debug") << "transform rewrite_witness (" << r
305 << "): " << weqr << std::endl;
306 addToTransChildren(weqr, tchildrenr);
307 }
308 }
309 Trace("smt-proof-pp-debug")
310 << "transform connect (" << r << ")" << std::endl;
311 // add to overall chain
312 if (r == 0)
313 {
314 // add the current chain to the overall chain
315 tchildren.insert(tchildren.end(), tchildrenr.begin(), tchildrenr.end());
316 }
317 else
318 {
319 // add the current chain to cdp
320 Node eqr = addProofForTrans(tchildrenr, cdp);
321 if (!eqr.isNull())
322 {
323 // take symmetry of above and add it to the overall chain
324 addToTransChildren(eqr, tchildren, true);
325 }
326 }
327 Trace("smt-proof-pp-debug")
328 << "transform finish (" << r << ")" << std::endl;
329 }
330
331 // children[0] = true
332 Node eq3 = children[0].eqNode(d_true);
333 Trace("smt-proof-pp-debug") << "transform true_intro: " << eq3 << std::endl;
334 cdp->addStep(eq3, PfRule::TRUE_INTRO, {children[0]}, {});
335 addToTransChildren(eq3, tchildren);
336
337 // apply transitivity if necessary
338 Node eq = addProofForTrans(tchildren, cdp);
339
340 cdp->addStep(args[0], PfRule::TRUE_ELIM, {eq}, {});
341 return args[0];
342 }
343 else if (id == PfRule::SUBS)
344 {
345 // Notice that a naive way to reconstruct SUBS is to do a term conversion
346 // proof for each substitution.
347 // The proof of f(a) * { a -> g(b) } * { b -> c } = f(g(c)) is:
348 // TRANS( CONG{f}( a=g(b) ), CONG{f}( CONG{g}( b=c ) ) )
349 // Notice that more optimal proofs are possible that do a single traversal
350 // over t. This is done by applying later substitutions to the range of
351 // previous substitutions, until a final simultaneous substitution is
352 // applied to t. For instance, in the above example, we first prove:
353 // CONG{g}( b = c )
354 // by applying the second substitution { b -> c } to the range of the first,
355 // giving us a proof of g(b)=g(c). We then construct the updated proof
356 // by tranitivity:
357 // TRANS( a=g(b), CONG{g}( b=c ) )
358 // We then apply the substitution { a -> g(c), b -> c } to f(a), to obtain:
359 // CONG{f}( TRANS( a=g(b), CONG{g}( b=c ) ) )
360 // which notice is more compact than the proof above.
361 Node t = args[0];
362 // get the kind of substitution
363 MethodId ids = MethodId::SB_DEFAULT;
364 if (args.size() >= 2)
365 {
366 builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids);
367 }
368 std::vector<std::shared_ptr<CDProof>> pfs;
369 std::vector<Node> vvec;
370 std::vector<Node> svec;
371 std::vector<ProofGenerator*> pgs;
372 for (size_t i = 0, nchild = children.size(); i < nchild; i++)
373 {
374 // process in reverse order
375 size_t index = nchild - (i + 1);
376 // get the substitution
377 TNode var, subs;
378 builtin::BuiltinProofRuleChecker::getSubstitution(
379 children[index], var, subs, ids);
380 // apply the current substitution to the range
381 if (!vvec.empty())
382 {
383 Node ss =
384 subs.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
385 if (ss != subs)
386 {
387 // make the proof for the tranitivity step
388 std::shared_ptr<CDProof> pf = std::make_shared<CDProof>(d_pnm);
389 pfs.push_back(pf);
390 // prove the updated substitution
391 TConvProofGenerator tcg(d_pnm, nullptr, TConvPolicy::ONCE);
392 // add previous rewrite steps
393 for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
394 {
395 tcg.addRewriteStep(vvec[j], svec[j], pgs[j]);
396 }
397 // get the proof for the update to the current substitution
398 Node seqss = subs.eqNode(ss);
399 std::shared_ptr<ProofNode> pfn = tcg.getProofFor(seqss);
400 Assert(pfn != nullptr);
401 // add the proof
402 pf->addProof(pfn);
403 // get proof for children[i] from cdp
404 pfn = cdp->getProofFor(children[i]);
405 pf->addProof(pfn);
406 // ensure we have a proof of var = subs
407 Node veqs = var.eqNode(subs);
408 if (veqs != children[index])
409 {
410 // should be true intro or false intro
411 Assert(subs.isConst());
412 pf->addStep(veqs,
413 subs.getConst<bool>() ? PfRule::TRUE_INTRO
414 : PfRule::FALSE_INTRO,
415 {children[index]},
416 {});
417 }
418 pf->addStep(var.eqNode(ss), PfRule::TRANS, {veqs, seqss}, {});
419 // add to the substitution
420 vvec.push_back(var);
421 svec.push_back(ss);
422 pgs.push_back(pf.get());
423 continue;
424 }
425 }
426 // just use equality from CDProof
427 vvec.push_back(var);
428 svec.push_back(subs);
429 pgs.push_back(cdp);
430 }
431 Node ts = t.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
432 Node eq = t.eqNode(ts);
433 if (ts != t)
434 {
435 // should be implied by the substitution now
436 TConvProofGenerator tcpg(d_pnm, nullptr, TConvPolicy::ONCE);
437 for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
438 {
439 tcpg.addRewriteStep(vvec[j], svec[j], pgs[j]);
440 }
441 // add the proof constructed by the term conversion utility
442 std::shared_ptr<ProofNode> pfn = tcpg.getProofFor(eq);
443 // should give a proof, if not, then tcpg does not agree with the
444 // substitution.
445 Assert(pfn != nullptr);
446 if (pfn != nullptr)
447 {
448 cdp->addProof(pfn);
449 }
450 }
451 else
452 {
453 // should not be necessary typically
454 cdp->addStep(eq, PfRule::REFL, {}, {t});
455 }
456 return eq;
457 }
458 else if (id == PfRule::REWRITE)
459 {
460 // get the kind of rewrite
461 MethodId idr = MethodId::RW_REWRITE;
462 if (args.size() >= 2)
463 {
464 builtin::BuiltinProofRuleChecker::getMethodId(args[1], idr);
465 }
466 builtin::BuiltinProofRuleChecker* builtinPfC =
467 static_cast<builtin::BuiltinProofRuleChecker*>(
468 d_pnm->getChecker()->getCheckerFor(PfRule::REWRITE));
469 Node ret = builtinPfC->applyRewrite(args[0], idr);
470 Node eq = args[0].eqNode(ret);
471 if (idr == MethodId::RW_REWRITE || idr == MethodId::RW_REWRITE_EQ_EXT)
472 {
473 // rewrites from theory::Rewriter
474 // automatically expand THEORY_REWRITE as well here if set
475 bool elimTR =
476 (d_elimRules.find(PfRule::THEORY_REWRITE) != d_elimRules.end());
477 bool isExtEq = (idr == MethodId::RW_REWRITE_EQ_EXT);
478 // use rewrite with proof interface
479 Rewriter* rr = d_smte->getRewriter();
480 TrustNode trn = rr->rewriteWithProof(args[0], elimTR, isExtEq);
481 std::shared_ptr<ProofNode> pfn =
482 trn.getGenerator()->getProofFor(trn.getProven());
483 cdp->addProof(pfn);
484 Assert(trn.getNode() == ret);
485 }
486 else if (idr == MethodId::RW_EVALUATE)
487 {
488 // change to evaluate, which is never eliminated
489 cdp->addStep(eq, PfRule::EVALUATE, {}, {args[0]});
490 }
491 else
492 {
493 // don't know how to eliminate
494 return Node::null();
495 }
496 if (args[0] == ret)
497 {
498 // should not be necessary typically
499 cdp->addStep(eq, PfRule::REFL, {}, {args[0]});
500 }
501 return eq;
502 }
503
504 // TRUST, PREPROCESS, THEORY_LEMMA, THEORY_PREPROCESS?
505
506 return Node::null();
507 }
508
509 Node ProofPostprocessCallback::addProofForWitnessForm(Node t, CDProof* cdp)
510 {
511 Node tw = SkolemManager::getWitnessForm(t);
512 Node eq = t.eqNode(tw);
513 if (t == tw)
514 {
515 // not necessary, add REFL step
516 cdp->addStep(eq, PfRule::REFL, {}, {t});
517 return eq;
518 }
519 std::shared_ptr<ProofNode> pn = d_wfpm.getProofFor(eq);
520 if (pn != nullptr)
521 {
522 // add the proof
523 cdp->addProof(pn);
524 }
525 else
526 {
527 Assert(false) << "ProofPostprocessCallback::addProofForWitnessForm: failed "
528 "to add proof for witness form of "
529 << t;
530 }
531 return eq;
532 }
533
534 Node ProofPostprocessCallback::addProofForTrans(
535 const std::vector<Node>& tchildren, CDProof* cdp)
536 {
537 size_t tsize = tchildren.size();
538 if (tsize > 1)
539 {
540 Node lhs = tchildren[0][0];
541 Node rhs = tchildren[tsize - 1][1];
542 Node eq = lhs.eqNode(rhs);
543 cdp->addStep(eq, PfRule::TRANS, tchildren, {});
544 return eq;
545 }
546 else if (tsize == 1)
547 {
548 return tchildren[0];
549 }
550 return Node::null();
551 }
552
553 bool ProofPostprocessCallback::addToTransChildren(Node eq,
554 std::vector<Node>& tchildren,
555 bool isSymm)
556 {
557 Assert(!eq.isNull());
558 Assert(eq.getKind() == kind::EQUAL);
559 if (eq[0] == eq[1])
560 {
561 return false;
562 }
563 Node equ = isSymm ? eq[1].eqNode(eq[0]) : eq;
564 Assert(tchildren.empty()
565 || (tchildren[tchildren.size() - 1].getKind() == kind::EQUAL
566 && tchildren[tchildren.size() - 1][1] == equ[0]));
567 tchildren.push_back(equ);
568 return true;
569 }
570
571 } // namespace smt
572 } // namespace CVC4