[proof-new] Fix bug in expansion of MACRO_RESOLUTION (#5845)
[cvc5.git] / src / smt / proof_post_processor.cpp
1 /********************* */
2 /*! \file proof_post_processor.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Haniel Barbosa
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of module for processing proof nodes
13 **/
14
15 #include "smt/proof_post_processor.h"
16
17 #include "expr/skolem_manager.h"
18 #include "options/smt_options.h"
19 #include "preprocessing/assertion_pipeline.h"
20 #include "smt/smt_engine.h"
21 #include "smt/smt_statistics_registry.h"
22 #include "theory/builtin/proof_checker.h"
23 #include "theory/rewriter.h"
24 #include "theory/theory.h"
25
26 using namespace CVC4::kind;
27 using namespace CVC4::theory;
28
29 namespace CVC4 {
30 namespace smt {
31
32 ProofPostprocessCallback::ProofPostprocessCallback(ProofNodeManager* pnm,
33 SmtEngine* smte,
34 ProofGenerator* pppg)
35 : d_pnm(pnm), d_smte(smte), d_pppg(pppg), d_wfpm(pnm)
36 {
37 d_true = NodeManager::currentNM()->mkConst(true);
38 // always check whether to update ASSUME
39 d_elimRules.insert(PfRule::ASSUME);
40 }
41
42 void ProofPostprocessCallback::initializeUpdate()
43 {
44 d_assumpToProof.clear();
45 d_wfAssumptions.clear();
46 }
47
48 void ProofPostprocessCallback::setEliminateRule(PfRule rule)
49 {
50 d_elimRules.insert(rule);
51 }
52
53 bool ProofPostprocessCallback::shouldUpdate(std::shared_ptr<ProofNode> pn,
54 bool& continueUpdate)
55 {
56 return d_elimRules.find(pn->getRule()) != d_elimRules.end();
57 }
58
59 bool ProofPostprocessCallback::update(Node res,
60 PfRule id,
61 const std::vector<Node>& children,
62 const std::vector<Node>& args,
63 CDProof* cdp,
64 bool& continueUpdate)
65 {
66 Trace("smt-proof-pp-debug") << "- Post process " << id << " " << children
67 << " / " << args << std::endl;
68
69 if (id == PfRule::ASSUME)
70 {
71 // we cache based on the assumption node, not the proof node, since there
72 // may be multiple occurrences of the same node.
73 Node f = args[0];
74 std::shared_ptr<ProofNode> pfn;
75 std::map<Node, std::shared_ptr<ProofNode>>::iterator it =
76 d_assumpToProof.find(f);
77 if (it != d_assumpToProof.end())
78 {
79 Trace("smt-proof-pp-debug") << "...already computed" << std::endl;
80 pfn = it->second;
81 }
82 else
83 {
84 Trace("smt-proof-pp-debug") << "...get proof" << std::endl;
85 Assert(d_pppg != nullptr);
86 // get proof from preprocess proof generator
87 pfn = d_pppg->getProofFor(f);
88 Trace("smt-proof-pp-debug") << "...finished get proof" << std::endl;
89 // print for debugging
90 if (pfn == nullptr)
91 {
92 Trace("smt-proof-pp-debug")
93 << "...no proof, possibly an input assumption" << std::endl;
94 }
95 else
96 {
97 Assert(pfn->getResult() == f);
98 if (Trace.isOn("smt-proof-pp"))
99 {
100 Trace("smt-proof-pp")
101 << "=== Connect proof for preprocessing: " << f << std::endl;
102 Trace("smt-proof-pp") << *pfn.get() << std::endl;
103 }
104 }
105 d_assumpToProof[f] = pfn;
106 }
107 if (pfn == nullptr || pfn->getRule() == PfRule::ASSUME)
108 {
109 Trace("smt-proof-pp-debug") << "...do not add proof" << std::endl;
110 // no update
111 return false;
112 }
113 Trace("smt-proof-pp-debug") << "...add proof" << std::endl;
114 // connect the proof
115 cdp->addProof(pfn);
116 return true;
117 }
118 Node ret = expandMacros(id, children, args, cdp);
119 Trace("smt-proof-pp-debug") << "...expanded = " << !ret.isNull() << std::endl;
120 return !ret.isNull();
121 }
122
123 bool ProofPostprocessCallback::updateInternal(Node res,
124 PfRule id,
125 const std::vector<Node>& children,
126 const std::vector<Node>& args,
127 CDProof* cdp)
128 {
129 bool continueUpdate = true;
130 return update(res, id, children, args, cdp, continueUpdate);
131 }
132
133 Node ProofPostprocessCallback::eliminateCrowdingLits(
134 const std::vector<Node>& clauseLits,
135 const std::vector<Node>& targetClauseLits,
136 const std::vector<Node>& children,
137 const std::vector<Node>& args,
138 CDProof* cdp)
139 {
140 NodeManager* nm = NodeManager::currentNM();
141 Node trueNode = nm->mkConst(true);
142 // get crowding lits and the position of the last clause that includes
143 // them. The factoring step must be added after the last inclusion and before
144 // its elimination.
145 std::unordered_set<TNode, TNodeHashFunction> crowding;
146 std::vector<std::pair<Node, size_t>> lastInclusion;
147 // positions of eliminators of crowding literals, which are the positions of
148 // the clauses that eliminate crowding literals *after* their last inclusion
149 std::vector<size_t> eliminators;
150 for (size_t i = 0, size = clauseLits.size(); i < size; ++i)
151 {
152 if (!crowding.count(clauseLits[i])
153 && std::find(
154 targetClauseLits.begin(), targetClauseLits.end(), clauseLits[i])
155 == targetClauseLits.end())
156 {
157 Node crowdLit = clauseLits[i];
158 crowding.insert(crowdLit);
159 Trace("smt-proof-pp-debug2") << "crowding lit " << crowdLit << "\n";
160 // found crowding lit, now get its last inclusion position, which is the
161 // position of the last resolution link that introduces the crowding
162 // literal. Note that this position has to be *before* the last link, as a
163 // link *after* the last inclusion must eliminate the crowding literal.
164 size_t j;
165 for (j = children.size() - 1; j > 0; --j)
166 {
167 // notice that only non-unit clauses may be introducing the crowding
168 // literal, so we only care about non-unit OR nodes. We check then
169 // against the kind and whether the whole OR node occurs as a pivot of
170 // the respective resolution
171 if (children[j - 1].getKind() != kind::OR)
172 {
173 continue;
174 }
175 uint64_t pivotIndex = 2 * (j - 1);
176 if (args[pivotIndex] == children[j - 1]
177 || args[pivotIndex].notNode() == children[j - 1])
178 {
179 continue;
180 }
181 if (std::find(children[j - 1].begin(), children[j - 1].end(), crowdLit)
182 != children[j - 1].end())
183 {
184 break;
185 }
186 }
187 Assert(j > 0);
188 lastInclusion.emplace_back(crowdLit, j - 1);
189
190 Trace("smt-proof-pp-debug2") << "last inc " << j - 1 << "\n";
191 // get elimination position, starting from the following link as the last
192 // inclusion one. The result is the last (in the chain, but first from
193 // this point on) resolution link that eliminates the crowding literal. A
194 // literal l is eliminated by a link if it contains a literal l' with
195 // opposite polarity to l.
196 for (; j < children.size(); ++j)
197 {
198 bool posFirst = args[(2 * j) - 1] == trueNode;
199 Node pivot = args[(2 * j)];
200 Trace("smt-proof-pp-debug2")
201 << "\tcheck w/ args " << posFirst << " / " << pivot << "\n";
202 // To eliminate the crowding literal (crowdLit), the clause must contain
203 // it with opposite polarity. There are three successful cases,
204 // according to the pivot and its sign
205 //
206 // - crowdLit is the same as the pivot and posFirst is true, which means
207 // that the clause contains its negation and eliminates it
208 //
209 // - crowdLit is the negation of the pivot and posFirst is false, so the
210 // clause contains the node whose negation is crowdLit. Note that this
211 // case may either be crowdLit.notNode() == pivot or crowdLit ==
212 // pivot.notNode().
213 if ((crowdLit == pivot && posFirst)
214 || (crowdLit.notNode() == pivot && !posFirst)
215 || (pivot.notNode() == crowdLit && !posFirst))
216 {
217 Trace("smt-proof-pp-debug2") << "\t\tfound it!\n";
218 eliminators.push_back(j);
219 break;
220 }
221 }
222 AlwaysAssert(j < children.size());
223 }
224 }
225 Assert(!lastInclusion.empty());
226 // order map so that we process crowding literals in the order of the clauses
227 // that last introduce them
228 auto cmp = [](std::pair<Node, size_t>& a, std::pair<Node, size_t>& b) {
229 return a.second < b.second;
230 };
231 std::sort(lastInclusion.begin(), lastInclusion.end(), cmp);
232 // order eliminators
233 std::sort(eliminators.begin(), eliminators.end());
234 if (Trace.isOn("smt-proof-pp-debug"))
235 {
236 Trace("smt-proof-pp-debug") << "crowding lits last inclusion:\n";
237 for (const auto& pair : lastInclusion)
238 {
239 Trace("smt-proof-pp-debug")
240 << "\t- [" << pair.second << "] : " << pair.first << "\n";
241 }
242 Trace("smt-proof-pp-debug") << "eliminators:";
243 for (size_t elim : eliminators)
244 {
245 Trace("smt-proof-pp-debug") << " " << elim;
246 }
247 Trace("smt-proof-pp-debug") << "\n";
248 }
249 // TODO (cvc4-wishues/issues/77): implement also simpler version and compare
250 //
251 // We now start to break the chain, one step at a time. Naively this breaking
252 // down would be one resolution/factoring to each crowding literal, but we can
253 // merge some of the cases. Effectively we do the following:
254 //
255 //
256 // lastClause children[start] ... children[end]
257 // ---------------------------------------------- CHAIN_RES
258 // C
259 // ----------- FACTORING
260 // lastClause' children[start'] ... children[end']
261 // -------------------------------------------------------------- CHAIN_RES
262 // ...
263 //
264 // where
265 // lastClause_0 = children[0]
266 // start_0 = 1
267 // end_0 = eliminators[0] - 1
268 // start_i+1 = nextGuardedElimPos - 1
269 //
270 // The important point is how end_i+1 is computed. It is based on what we call
271 // the "nextGuardedElimPos", i.e., the next elimination position that requires
272 // removal of duplicates. The intuition is that a factoring step may eliminate
273 // the duplicates of crowding literals l1 and l2. If the last inclusion of l2
274 // is before the elimination of l1, then we can go ahead and also perform the
275 // elimination of l2 without another factoring. However if another literal l3
276 // has its last inclusion after the elimination of l2, then the elimination of
277 // l3 is the next guarded elimination.
278 //
279 // To do the above computation then we determine, after a resolution/factoring
280 // step, the first crowded literal to have its last inclusion after "end". The
281 // first elimination position to be bigger than the position of that crowded
282 // literal is the next guarded elimination position.
283 size_t lastElim = 0;
284 Node lastClause = children[0];
285 std::vector<Node> childrenRes;
286 std::vector<Node> childrenResArgs;
287 Node resPlaceHolder;
288 size_t nextGuardedElimPos = eliminators[0];
289 do
290 {
291 size_t start = lastElim + 1;
292 size_t end = nextGuardedElimPos - 1;
293 Trace("smt-proof-pp-debug2")
294 << "res with:\n\tlastClause: " << lastClause << "\n\tstart: " << start
295 << "\n\tend: " << end << "\n";
296 childrenRes.push_back(lastClause);
297 // note that the interval of insert is exclusive in the end, so we add 1
298 childrenRes.insert(childrenRes.end(),
299 children.begin() + start,
300 children.begin() + end + 1);
301 childrenResArgs.insert(childrenResArgs.end(),
302 args.begin() + (2 * start) - 1,
303 args.begin() + (2 * end) + 1);
304 Trace("smt-proof-pp-debug2") << "res children: " << childrenRes << "\n";
305 Trace("smt-proof-pp-debug2") << "res args: " << childrenResArgs << "\n";
306 resPlaceHolder = d_pnm->getChecker()->checkDebug(PfRule::CHAIN_RESOLUTION,
307 childrenRes,
308 childrenResArgs,
309 Node::null(),
310 "");
311 Trace("smt-proof-pp-debug2")
312 << "resPlaceHorder: " << resPlaceHolder << "\n";
313 cdp->addStep(
314 resPlaceHolder, PfRule::CHAIN_RESOLUTION, childrenRes, childrenResArgs);
315 // I need to add factoring if end < children.size(). Otherwise, this is
316 // to be handled by the caller
317 if (end < children.size() - 1)
318 {
319 lastClause = d_pnm->getChecker()->checkDebug(
320 PfRule::FACTORING, {resPlaceHolder}, {}, Node::null(), "");
321 if (!lastClause.isNull())
322 {
323 cdp->addStep(lastClause, PfRule::FACTORING, {resPlaceHolder}, {});
324 }
325 else
326 {
327 lastClause = resPlaceHolder;
328 }
329 Trace("smt-proof-pp-debug2") << "lastClause: " << lastClause << "\n";
330 }
331 else
332 {
333 lastClause = resPlaceHolder;
334 break;
335 }
336 // update for next round
337 childrenRes.clear();
338 childrenResArgs.clear();
339 lastElim = end;
340
341 // find the position of the last inclusion of the next crowded literal
342 size_t nextCrowdedInclusionPos = lastInclusion.size();
343 for (size_t i = 0, size = lastInclusion.size(); i < size; ++i)
344 {
345 if (lastInclusion[i].second > lastElim)
346 {
347 nextCrowdedInclusionPos = i;
348 break;
349 }
350 }
351 Trace("smt-proof-pp-debug2")
352 << "nextCrowdedInclusion/Pos: "
353 << lastInclusion[nextCrowdedInclusionPos].second << "/"
354 << nextCrowdedInclusionPos << "\n";
355 // if there is none, then the remaining literals will be used in the next
356 // round
357 if (nextCrowdedInclusionPos == lastInclusion.size())
358 {
359 nextGuardedElimPos = children.size();
360 }
361 else
362 {
363 nextGuardedElimPos = children.size();
364 for (size_t i = 0, size = eliminators.size(); i < size; ++i)
365 {
366 // nextGuardedElimPos is the largest element of
367 // eliminators bigger the next crowded literal's last inclusion
368 if (eliminators[i] > lastInclusion[nextCrowdedInclusionPos].second)
369 {
370 nextGuardedElimPos = eliminators[i];
371 break;
372 }
373 }
374 Assert(nextGuardedElimPos < children.size());
375 }
376 Trace("smt-proof-pp-debug2")
377 << "nextGuardedElimPos: " << nextGuardedElimPos << "\n";
378 } while (true);
379 return lastClause;
380 }
381
382 Node ProofPostprocessCallback::expandMacros(PfRule id,
383 const std::vector<Node>& children,
384 const std::vector<Node>& args,
385 CDProof* cdp)
386 {
387 if (d_elimRules.find(id) == d_elimRules.end())
388 {
389 // not eliminated
390 return Node::null();
391 }
392 // macro elimination
393 if (id == PfRule::MACRO_SR_EQ_INTRO)
394 {
395 // (TRANS
396 // (SUBS <children> :args args[0:1])
397 // (REWRITE :args <t.substitute(x1,t1). ... .substitute(xn,tn)> args[2]))
398 std::vector<Node> tchildren;
399 Node t = args[0];
400 Node ts;
401 if (!children.empty())
402 {
403 std::vector<Node> sargs;
404 sargs.push_back(t);
405 MethodId sid = MethodId::SB_DEFAULT;
406 if (args.size() >= 2)
407 {
408 if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], sid))
409 {
410 sargs.push_back(args[1]);
411 }
412 }
413 ts =
414 builtin::BuiltinProofRuleChecker::applySubstitution(t, children, sid);
415 Trace("smt-proof-pp-debug")
416 << "...eq intro subs equality is " << t << " == " << ts << ", from "
417 << sid << std::endl;
418 if (ts != t)
419 {
420 Node eq = t.eqNode(ts);
421 // apply SUBS proof rule if necessary
422 if (!updateInternal(eq, PfRule::SUBS, children, sargs, cdp))
423 {
424 // if we specified that we did not want to eliminate, add as step
425 cdp->addStep(eq, PfRule::SUBS, children, sargs);
426 }
427 tchildren.push_back(eq);
428 }
429 }
430 else
431 {
432 // no substitute
433 ts = t;
434 }
435 std::vector<Node> rargs;
436 rargs.push_back(ts);
437 MethodId rid = MethodId::RW_REWRITE;
438 if (args.size() >= 3)
439 {
440 if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], rid))
441 {
442 rargs.push_back(args[2]);
443 }
444 }
445 builtin::BuiltinProofRuleChecker* builtinPfC =
446 static_cast<builtin::BuiltinProofRuleChecker*>(
447 d_pnm->getChecker()->getCheckerFor(PfRule::MACRO_SR_EQ_INTRO));
448 Node tr = builtinPfC->applyRewrite(ts, rid);
449 Trace("smt-proof-pp-debug")
450 << "...eq intro rewrite equality is " << ts << " == " << tr << ", from "
451 << rid << std::endl;
452 if (ts != tr)
453 {
454 Node eq = ts.eqNode(tr);
455 // apply REWRITE proof rule
456 if (!updateInternal(eq, PfRule::REWRITE, {}, rargs, cdp))
457 {
458 // if not eliminated, add as step
459 cdp->addStep(eq, PfRule::REWRITE, {}, rargs);
460 }
461 tchildren.push_back(eq);
462 }
463 if (t == tr)
464 {
465 // typically not necessary, but done to be robust
466 cdp->addStep(t.eqNode(tr), PfRule::REFL, {}, {t});
467 return t.eqNode(tr);
468 }
469 // must add TRANS if two step
470 return addProofForTrans(tchildren, cdp);
471 }
472 else if (id == PfRule::MACRO_SR_PRED_INTRO)
473 {
474 std::vector<Node> tchildren;
475 std::vector<Node> sargs = args;
476 // take into account witness form, if necessary
477 bool reqWitness = d_wfpm.requiresWitnessFormIntro(args[0]);
478 Trace("smt-proof-pp-debug")
479 << "...pred intro reqWitness=" << reqWitness << std::endl;
480 // (TRUE_ELIM
481 // (TRANS
482 // (MACRO_SR_EQ_INTRO <children> :args (t args[1:]))
483 // ... proof of apply_SR(t) = toWitness(apply_SR(t)) ...
484 // (MACRO_SR_EQ_INTRO {} {toWitness(apply_SR(t))})
485 // ))
486 // Notice this is an optimized, one sided version of the expansion of
487 // MACRO_SR_PRED_TRANSFORM below.
488 // We call the expandMacros method on MACRO_SR_EQ_INTRO, where notice
489 // that this rule application is immediately expanded in the recursive
490 // call and not added to the proof.
491 Node conc = expandMacros(PfRule::MACRO_SR_EQ_INTRO, children, sargs, cdp);
492 Trace("smt-proof-pp-debug")
493 << "...pred intro conclusion is " << conc << std::endl;
494 Assert(!conc.isNull());
495 Assert(conc.getKind() == EQUAL);
496 Assert(conc[0] == args[0]);
497 tchildren.push_back(conc);
498 if (reqWitness)
499 {
500 Node weq = addProofForWitnessForm(conc[1], cdp);
501 Trace("smt-proof-pp-debug") << "...weq is " << weq << std::endl;
502 if (addToTransChildren(weq, tchildren))
503 {
504 // toWitness(apply_SR(t)) = apply_SR(toWitness(apply_SR(t)))
505 // rewrite again, don't need substitution. Also we always use the
506 // default rewriter, due to the definition of MACRO_SR_PRED_INTRO.
507 Node weqr = expandMacros(PfRule::MACRO_SR_EQ_INTRO, {}, {weq[1]}, cdp);
508 addToTransChildren(weqr, tchildren);
509 }
510 }
511 // apply transitivity if necessary
512 Node eq = addProofForTrans(tchildren, cdp);
513 Assert(!eq.isNull());
514 Assert(eq.getKind() == EQUAL);
515 Assert(eq[0] == args[0]);
516 Assert(eq[1] == d_true);
517
518 cdp->addStep(eq[0], PfRule::TRUE_ELIM, {eq}, {});
519 return eq[0];
520 }
521 else if (id == PfRule::MACRO_SR_PRED_ELIM)
522 {
523 // (EQ_RESOLVE
524 // children[0]
525 // (MACRO_SR_EQ_INTRO children[1:] :args children[0] ++ args))
526 std::vector<Node> schildren(children.begin() + 1, children.end());
527 std::vector<Node> srargs;
528 srargs.push_back(children[0]);
529 srargs.insert(srargs.end(), args.begin(), args.end());
530 Node conc = expandMacros(PfRule::MACRO_SR_EQ_INTRO, schildren, srargs, cdp);
531 Assert(!conc.isNull());
532 Assert(conc.getKind() == EQUAL);
533 Assert(conc[0] == children[0]);
534 // apply equality resolve
535 cdp->addStep(conc[1], PfRule::EQ_RESOLVE, {children[0], conc}, {});
536 return conc[1];
537 }
538 else if (id == PfRule::MACRO_SR_PRED_TRANSFORM)
539 {
540 // (EQ_RESOLVE
541 // children[0]
542 // (TRANS
543 // (MACRO_SR_EQ_INTRO children[1:] :args (children[0] args[1:]))
544 // ... proof of c = wc
545 // (MACRO_SR_EQ_INTRO {} wc)
546 // (SYMM
547 // (MACRO_SR_EQ_INTRO children[1:] :args <args>)
548 // ... proof of a = wa
549 // (MACRO_SR_EQ_INTRO {} wa))))
550 // where
551 // wa = toWitness(apply_SR(args[0])) and
552 // wc = toWitness(apply_SR(children[0])).
553 Trace("smt-proof-pp-debug")
554 << "Transform " << children[0] << " == " << args[0] << std::endl;
555 if (CDProof::isSame(children[0], args[0]))
556 {
557 Trace("smt-proof-pp-debug") << "...nothing to do" << std::endl;
558 // nothing to do
559 return children[0];
560 }
561 std::vector<Node> tchildren;
562 std::vector<Node> schildren(children.begin() + 1, children.end());
563 std::vector<Node> sargs = args;
564 // first, compute if we need to take witness form into account
565 bool reqWitness = d_wfpm.requiresWitnessFormTransform(children[0], args[0]);
566 Trace("smt-proof-pp-debug") << "...reqWitness=" << reqWitness << std::endl;
567 // convert both sides, in three steps, take symmetry of second chain
568 for (unsigned r = 0; r < 2; r++)
569 {
570 std::vector<Node> tchildrenr;
571 // first rewrite children[0], then args[0]
572 sargs[0] = r == 0 ? children[0] : args[0];
573 // t = apply_SR(t)
574 Node eq = expandMacros(PfRule::MACRO_SR_EQ_INTRO, schildren, sargs, cdp);
575 Trace("smt-proof-pp-debug")
576 << "transform subs_rewrite (" << r << "): " << eq << std::endl;
577 Assert(!eq.isNull() && eq.getKind() == EQUAL && eq[0] == sargs[0]);
578 addToTransChildren(eq, tchildrenr);
579 // apply_SR(t) = toWitness(apply_SR(t))
580 if (reqWitness)
581 {
582 Node weq = addProofForWitnessForm(eq[1], cdp);
583 Trace("smt-proof-pp-debug")
584 << "transform toWitness (" << r << "): " << weq << std::endl;
585 if (addToTransChildren(weq, tchildrenr))
586 {
587 // toWitness(apply_SR(t)) = apply_SR(toWitness(apply_SR(t)))
588 // rewrite again, don't need substitution. Also, we always use the
589 // default rewriter, due to the definition of MACRO_SR_PRED_TRANSFORM.
590 Node weqr =
591 expandMacros(PfRule::MACRO_SR_EQ_INTRO, {}, {weq[1]}, cdp);
592 Trace("smt-proof-pp-debug") << "transform rewrite_witness (" << r
593 << "): " << weqr << std::endl;
594 addToTransChildren(weqr, tchildrenr);
595 }
596 }
597 Trace("smt-proof-pp-debug")
598 << "transform connect (" << r << ")" << std::endl;
599 // add to overall chain
600 if (r == 0)
601 {
602 // add the current chain to the overall chain
603 tchildren.insert(tchildren.end(), tchildrenr.begin(), tchildrenr.end());
604 }
605 else
606 {
607 // add the current chain to cdp
608 Node eqr = addProofForTrans(tchildrenr, cdp);
609 if (!eqr.isNull())
610 {
611 Trace("smt-proof-pp-debug") << "transform connect sym " << tchildren
612 << " " << eqr << std::endl;
613 // take symmetry of above and add it to the overall chain
614 addToTransChildren(eqr, tchildren, true);
615 }
616 }
617 Trace("smt-proof-pp-debug")
618 << "transform finish (" << r << ")" << std::endl;
619 }
620
621 // apply transitivity if necessary
622 Node eq = addProofForTrans(tchildren, cdp);
623
624 cdp->addStep(args[0], PfRule::EQ_RESOLVE, {children[0], eq}, {});
625 return args[0];
626 }
627 else if (id == PfRule::MACRO_RESOLUTION)
628 {
629 // first generate the naive chain_resolution
630 std::vector<Node> chainResArgs{args.begin() + 1, args.end()};
631 Node chainConclusion = d_pnm->getChecker()->checkDebug(
632 PfRule::CHAIN_RESOLUTION, children, chainResArgs, Node::null(), "");
633 Trace("smt-proof-pp-debug") << "Original conclusion: " << args[0] << "\n";
634 Trace("smt-proof-pp-debug")
635 << "chainRes conclusion: " << chainConclusion << "\n";
636 // There are n cases:
637 // - if the conclusion is the same, just replace
638 // - if they have the same literals but in different quantity, add a
639 // FACTORING step
640 // - if the order is not the same, add a REORDERING step
641 // - if there are literals in chainConclusion that are not in the original
642 // conclusion, we need to transform the MACRO_RESOLUTION into a series of
643 // CHAIN_RESOLUTION + FACTORING steps, so that we explicitly eliminate all
644 // these "crowding" literals. We do this via FACTORING so we avoid adding
645 // an exponential number of premises, which would happen if we just
646 // repeated in the premises the clauses needed for eliminating crowding
647 // literals, which could themselves add crowding literals.
648 if (chainConclusion == args[0])
649 {
650 cdp->addStep(
651 chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs);
652 return chainConclusion;
653 }
654 NodeManager* nm = NodeManager::currentNM();
655 // If we got here, then chainConclusion is NECESSARILY an OR node
656 Assert(chainConclusion.getKind() == kind::OR);
657 // get the literals in the chain conclusion
658 std::vector<Node> chainConclusionLits{chainConclusion.begin(),
659 chainConclusion.end()};
660 std::set<Node> chainConclusionLitsSet{chainConclusion.begin(),
661 chainConclusion.end()};
662 // is args[0] a unit clause? If it's not an OR node, then yes. Otherwise,
663 // it's only a unit if it occurs in chainConclusionLitsSet
664 std::vector<Node> conclusionLits;
665 // whether conclusion is unit
666 if (chainConclusionLitsSet.count(args[0]))
667 {
668 conclusionLits.push_back(args[0]);
669 }
670 else
671 {
672 Assert(args[0].getKind() == kind::OR);
673 conclusionLits.insert(
674 conclusionLits.end(), args[0].begin(), args[0].end());
675 }
676 std::set<Node> conclusionLitsSet{conclusionLits.begin(),
677 conclusionLits.end()};
678 // If the sets are different, there are "crowding" literals, i.e. literals
679 // that were removed by implicit multi-usage of premises in the resolution
680 // chain.
681 if (chainConclusionLitsSet != conclusionLitsSet)
682 {
683 chainConclusion = eliminateCrowdingLits(
684 chainConclusionLits, conclusionLits, children, args, cdp);
685 // update vector of lits. Note that the set is no longer used, so we don't
686 // need to update it
687 chainConclusionLits.clear();
688 chainConclusionLits.insert(chainConclusionLits.end(),
689 chainConclusion.begin(),
690 chainConclusion.end());
691 }
692 else
693 {
694 cdp->addStep(
695 chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs);
696 }
697 // Placeholder for running conclusion
698 Node n = chainConclusion;
699 // factoring
700 if (chainConclusionLits.size() != conclusionLits.size())
701 {
702 // We build it rather than taking conclusionLits because the order may be
703 // different
704 std::vector<Node> factoredLits;
705 std::unordered_set<TNode, TNodeHashFunction> clauseSet;
706 for (size_t i = 0, size = chainConclusionLits.size(); i < size; ++i)
707 {
708 if (clauseSet.count(chainConclusionLits[i]))
709 {
710 continue;
711 }
712 factoredLits.push_back(n[i]);
713 clauseSet.insert(n[i]);
714 }
715 Node factored = factoredLits.empty()
716 ? nm->mkConst(false)
717 : factoredLits.size() == 1
718 ? factoredLits[0]
719 : nm->mkNode(kind::OR, factoredLits);
720 cdp->addStep(factored, PfRule::FACTORING, {n}, {});
721 n = factored;
722 }
723 // either same node or n as a clause
724 Assert(n == args[0] || n.getKind() == kind::OR);
725 // reordering
726 if (n != args[0])
727 {
728 cdp->addStep(args[0], PfRule::REORDERING, {n}, {args[0]});
729 }
730 return args[0];
731 }
732 else if (id == PfRule::SUBS)
733 {
734 NodeManager* nm = NodeManager::currentNM();
735 // Notice that a naive way to reconstruct SUBS is to do a term conversion
736 // proof for each substitution.
737 // The proof of f(a) * { a -> g(b) } * { b -> c } = f(g(c)) is:
738 // TRANS( CONG{f}( a=g(b) ), CONG{f}( CONG{g}( b=c ) ) )
739 // Notice that more optimal proofs are possible that do a single traversal
740 // over t. This is done by applying later substitutions to the range of
741 // previous substitutions, until a final simultaneous substitution is
742 // applied to t. For instance, in the above example, we first prove:
743 // CONG{g}( b = c )
744 // by applying the second substitution { b -> c } to the range of the first,
745 // giving us a proof of g(b)=g(c). We then construct the updated proof
746 // by transitivity:
747 // TRANS( a=g(b), CONG{g}( b=c ) )
748 // We then apply the substitution { a -> g(c), b -> c } to f(a), to obtain:
749 // CONG{f}( TRANS( a=g(b), CONG{g}( b=c ) ) )
750 // which notice is more compact than the proof above.
751 Node t = args[0];
752 // get the kind of substitution
753 MethodId ids = MethodId::SB_DEFAULT;
754 if (args.size() >= 2)
755 {
756 builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids);
757 }
758 std::vector<std::shared_ptr<CDProof>> pfs;
759 std::vector<TNode> vsList;
760 std::vector<TNode> ssList;
761 std::vector<TNode> fromList;
762 std::vector<ProofGenerator*> pgs;
763 // first, compute the entire substitution
764 for (size_t i = 0, nchild = children.size(); i < nchild; i++)
765 {
766 // get the substitution
767 builtin::BuiltinProofRuleChecker::getSubstitutionFor(
768 children[i], vsList, ssList, fromList, ids);
769 // ensure proofs for each formula in fromList
770 if (children[i].getKind() == AND && ids == MethodId::SB_DEFAULT)
771 {
772 for (size_t j = 0, nchildi = children[i].getNumChildren(); j < nchildi;
773 j++)
774 {
775 Node nodej = nm->mkConst(Rational(j));
776 cdp->addStep(
777 children[i][j], PfRule::AND_ELIM, {children[i]}, {nodej});
778 }
779 }
780 }
781 std::vector<Node> vvec;
782 std::vector<Node> svec;
783 for (size_t i = 0, nvs = vsList.size(); i < nvs; i++)
784 {
785 // Note we process in forward order, since later substitution should be
786 // applied to earlier ones, and the last child of a SUBS is processed
787 // first.
788 TNode var = vsList[i];
789 TNode subs = ssList[i];
790 TNode childFrom = fromList[i];
791 Trace("smt-proof-pp-debug")
792 << "...process " << var << " -> " << subs << " (" << childFrom << ", "
793 << ids << ")" << std::endl;
794 // apply the current substitution to the range
795 if (!vvec.empty())
796 {
797 Node ss =
798 subs.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
799 if (ss != subs)
800 {
801 Trace("smt-proof-pp-debug")
802 << "......updated to " << var << " -> " << ss
803 << " based on previous substitution" << std::endl;
804 // make the proof for the transitivity step
805 std::shared_ptr<CDProof> pf = std::make_shared<CDProof>(d_pnm);
806 pfs.push_back(pf);
807 // prove the updated substitution
808 TConvProofGenerator tcg(d_pnm,
809 nullptr,
810 TConvPolicy::ONCE,
811 TConvCachePolicy::NEVER,
812 "nested_SUBS_TConvProofGenerator",
813 nullptr,
814 true);
815 // add previous rewrite steps
816 for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
817 {
818 // substitutions are pre-rewrites
819 tcg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
820 }
821 // get the proof for the update to the current substitution
822 Node seqss = subs.eqNode(ss);
823 std::shared_ptr<ProofNode> pfn = tcg.getProofFor(seqss);
824 Assert(pfn != nullptr);
825 // add the proof
826 pf->addProof(pfn);
827 // get proof for childFrom from cdp
828 pfn = cdp->getProofFor(childFrom);
829 pf->addProof(pfn);
830 // ensure we have a proof of var = subs
831 Node veqs = addProofForSubsStep(var, subs, childFrom, pf.get());
832 // transitivity
833 pf->addStep(var.eqNode(ss), PfRule::TRANS, {veqs, seqss}, {});
834 // add to the substitution
835 vvec.push_back(var);
836 svec.push_back(ss);
837 pgs.push_back(pf.get());
838 continue;
839 }
840 }
841 // Just use equality from CDProof, but ensure we have a proof in cdp.
842 // This may involve a TRUE_INTRO/FALSE_INTRO if the substitution step
843 // uses the assumption childFrom as a Boolean assignment (e.g.
844 // childFrom = true if we are using MethodId::SB_LITERAL).
845 addProofForSubsStep(var, subs, childFrom, cdp);
846 vvec.push_back(var);
847 svec.push_back(subs);
848 pgs.push_back(cdp);
849 }
850 Node ts = t.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
851 Node eq = t.eqNode(ts);
852 if (ts != t)
853 {
854 // should be implied by the substitution now
855 TConvProofGenerator tcpg(d_pnm,
856 nullptr,
857 TConvPolicy::ONCE,
858 TConvCachePolicy::NEVER,
859 "SUBS_TConvProofGenerator",
860 nullptr,
861 true);
862 for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
863 {
864 // substitutions are pre-rewrites
865 tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
866 }
867 // add the proof constructed by the term conversion utility
868 std::shared_ptr<ProofNode> pfn = tcpg.getProofFor(eq);
869 // should give a proof, if not, then tcpg does not agree with the
870 // substitution.
871 Assert(pfn != nullptr);
872 if (pfn == nullptr)
873 {
874 cdp->addStep(eq, PfRule::TRUST_SUBS, {}, {eq});
875 }
876 else
877 {
878 cdp->addProof(pfn);
879 }
880 }
881 else
882 {
883 // should not be necessary typically
884 cdp->addStep(eq, PfRule::REFL, {}, {t});
885 }
886 return eq;
887 }
888 else if (id == PfRule::REWRITE)
889 {
890 // get the kind of rewrite
891 MethodId idr = MethodId::RW_REWRITE;
892 if (args.size() >= 2)
893 {
894 builtin::BuiltinProofRuleChecker::getMethodId(args[1], idr);
895 }
896 builtin::BuiltinProofRuleChecker* builtinPfC =
897 static_cast<builtin::BuiltinProofRuleChecker*>(
898 d_pnm->getChecker()->getCheckerFor(PfRule::REWRITE));
899 Node ret = builtinPfC->applyRewrite(args[0], idr);
900 Node eq = args[0].eqNode(ret);
901 if (idr == MethodId::RW_REWRITE || idr == MethodId::RW_REWRITE_EQ_EXT)
902 {
903 // rewrites from theory::Rewriter
904 bool isExtEq = (idr == MethodId::RW_REWRITE_EQ_EXT);
905 // use rewrite with proof interface
906 Rewriter* rr = d_smte->getRewriter();
907 TrustNode trn = rr->rewriteWithProof(args[0], isExtEq);
908 std::shared_ptr<ProofNode> pfn = trn.toProofNode();
909 if (pfn == nullptr)
910 {
911 Trace("smt-proof-pp-debug")
912 << "Use TRUST_REWRITE for " << eq << std::endl;
913 // did not have a proof of rewriting, probably isExtEq is true
914 if (isExtEq)
915 {
916 // update to THEORY_REWRITE with idr
917 Assert(args.size() >= 1);
918 TheoryId theoryId = Theory::theoryOf(args[0].getType());
919 Node tid = builtin::BuiltinProofRuleChecker::mkTheoryIdNode(theoryId);
920 cdp->addStep(eq, PfRule::THEORY_REWRITE, {}, {eq, tid, args[1]});
921 }
922 else
923 {
924 // this should never be applied
925 cdp->addStep(eq, PfRule::TRUST_REWRITE, {}, {eq});
926 }
927 }
928 else
929 {
930 cdp->addProof(pfn);
931 }
932 Assert(trn.getNode() == ret)
933 << "Unexpected rewrite " << args[0] << std::endl
934 << "Got: " << trn.getNode() << std::endl
935 << "Expected: " << ret;
936 }
937 else if (idr == MethodId::RW_EVALUATE)
938 {
939 // change to evaluate, which is never eliminated
940 cdp->addStep(eq, PfRule::EVALUATE, {}, {args[0]});
941 }
942 else
943 {
944 // don't know how to eliminate
945 return Node::null();
946 }
947 if (args[0] == ret)
948 {
949 // should not be necessary typically
950 cdp->addStep(eq, PfRule::REFL, {}, {args[0]});
951 }
952 return eq;
953 }
954 else if (id == PfRule::THEORY_REWRITE)
955 {
956 Assert(!args.empty());
957 Node eq = args[0];
958 Assert(eq.getKind() == EQUAL);
959 // try to replay theory rewrite
960 // first, check that maybe its just an evaluation step
961 ProofChecker* pc = d_pnm->getChecker();
962 Node ceval =
963 pc->checkDebug(PfRule::EVALUATE, {}, {eq[0]}, eq, "smt-proof-pp-debug");
964 if (!ceval.isNull() && ceval == eq)
965 {
966 cdp->addStep(eq, PfRule::EVALUATE, {}, {eq[0]});
967 return eq;
968 }
969 // otherwise no update
970 Trace("final-pf-hole") << "hole: " << id << " : " << eq << std::endl;
971 }
972
973 // TRUST, PREPROCESS, THEORY_LEMMA, THEORY_PREPROCESS?
974
975 return Node::null();
976 }
977
978 Node ProofPostprocessCallback::addProofForWitnessForm(Node t, CDProof* cdp)
979 {
980 Node tw = SkolemManager::getWitnessForm(t);
981 Node eq = t.eqNode(tw);
982 if (t == tw)
983 {
984 // not necessary, add REFL step
985 cdp->addStep(eq, PfRule::REFL, {}, {t});
986 return eq;
987 }
988 std::shared_ptr<ProofNode> pn = d_wfpm.getProofFor(eq);
989 if (pn != nullptr)
990 {
991 // add the proof
992 cdp->addProof(pn);
993 }
994 else
995 {
996 Assert(false) << "ProofPostprocessCallback::addProofForWitnessForm: failed "
997 "to add proof for witness form of "
998 << t;
999 }
1000 return eq;
1001 }
1002
1003 Node ProofPostprocessCallback::addProofForTrans(
1004 const std::vector<Node>& tchildren, CDProof* cdp)
1005 {
1006 size_t tsize = tchildren.size();
1007 if (tsize > 1)
1008 {
1009 Node lhs = tchildren[0][0];
1010 Node rhs = tchildren[tsize - 1][1];
1011 Node eq = lhs.eqNode(rhs);
1012 cdp->addStep(eq, PfRule::TRANS, tchildren, {});
1013 return eq;
1014 }
1015 else if (tsize == 1)
1016 {
1017 return tchildren[0];
1018 }
1019 return Node::null();
1020 }
1021
1022 Node ProofPostprocessCallback::addProofForSubsStep(Node var,
1023 Node subs,
1024 Node assump,
1025 CDProof* cdp)
1026 {
1027 // ensure we have a proof of var = subs
1028 Node veqs = var.eqNode(subs);
1029 if (veqs != assump)
1030 {
1031 // should be true intro or false intro
1032 Assert(subs.isConst());
1033 cdp->addStep(
1034 veqs,
1035 subs.getConst<bool>() ? PfRule::TRUE_INTRO : PfRule::FALSE_INTRO,
1036 {assump},
1037 {});
1038 }
1039 return veqs;
1040 }
1041
1042 bool ProofPostprocessCallback::addToTransChildren(Node eq,
1043 std::vector<Node>& tchildren,
1044 bool isSymm)
1045 {
1046 Assert(!eq.isNull());
1047 Assert(eq.getKind() == kind::EQUAL);
1048 if (eq[0] == eq[1])
1049 {
1050 return false;
1051 }
1052 Node equ = isSymm ? eq[1].eqNode(eq[0]) : eq;
1053 Assert(tchildren.empty()
1054 || (tchildren[tchildren.size() - 1].getKind() == kind::EQUAL
1055 && tchildren[tchildren.size() - 1][1] == equ[0]));
1056 tchildren.push_back(equ);
1057 return true;
1058 }
1059
1060 ProofPostprocessFinalCallback::ProofPostprocessFinalCallback(
1061 ProofNodeManager* pnm)
1062 : d_ruleCount("finalProof::ruleCount"),
1063 d_totalRuleCount("finalProof::totalRuleCount", 0),
1064 d_minPedanticLevel("finalProof::minPedanticLevel", 10),
1065 d_numFinalProofs("finalProofs::numFinalProofs", 0),
1066 d_pnm(pnm),
1067 d_pedanticFailure(false)
1068 {
1069 smtStatisticsRegistry()->registerStat(&d_ruleCount);
1070 smtStatisticsRegistry()->registerStat(&d_totalRuleCount);
1071 smtStatisticsRegistry()->registerStat(&d_minPedanticLevel);
1072 smtStatisticsRegistry()->registerStat(&d_numFinalProofs);
1073 }
1074
1075 ProofPostprocessFinalCallback::~ProofPostprocessFinalCallback()
1076 {
1077 smtStatisticsRegistry()->unregisterStat(&d_ruleCount);
1078 smtStatisticsRegistry()->unregisterStat(&d_totalRuleCount);
1079 smtStatisticsRegistry()->unregisterStat(&d_minPedanticLevel);
1080 smtStatisticsRegistry()->unregisterStat(&d_numFinalProofs);
1081 }
1082
1083 void ProofPostprocessFinalCallback::initializeUpdate()
1084 {
1085 d_pedanticFailure = false;
1086 d_pedanticFailureOut.str("");
1087 ++d_numFinalProofs;
1088 }
1089
1090 bool ProofPostprocessFinalCallback::shouldUpdate(std::shared_ptr<ProofNode> pn,
1091 bool& continueUpdate)
1092 {
1093 PfRule r = pn->getRule();
1094 // if not doing eager pedantic checking, fail if below threshold
1095 if (!options::proofNewEagerChecking())
1096 {
1097 if (!d_pedanticFailure)
1098 {
1099 Assert(d_pedanticFailureOut.str().empty());
1100 if (d_pnm->getChecker()->isPedanticFailure(r, d_pedanticFailureOut))
1101 {
1102 d_pedanticFailure = true;
1103 }
1104 }
1105 }
1106 uint32_t plevel = d_pnm->getChecker()->getPedanticLevel(r);
1107 if (plevel != 0)
1108 {
1109 d_minPedanticLevel.minAssign(plevel);
1110 }
1111 // record stats for the rule
1112 d_ruleCount << r;
1113 ++d_totalRuleCount;
1114 return false;
1115 }
1116
1117 bool ProofPostprocessFinalCallback::wasPedanticFailure(std::ostream& out) const
1118 {
1119 if (d_pedanticFailure)
1120 {
1121 out << d_pedanticFailureOut.str();
1122 return true;
1123 }
1124 return false;
1125 }
1126
1127 ProofPostproccess::ProofPostproccess(ProofNodeManager* pnm,
1128 SmtEngine* smte,
1129 ProofGenerator* pppg)
1130 : d_pnm(pnm),
1131 d_cb(pnm, smte, pppg),
1132 // the update merges subproofs
1133 d_updater(d_pnm, d_cb, true),
1134 d_finalCb(pnm),
1135 d_finalizer(d_pnm, d_finalCb)
1136 {
1137 }
1138
1139 ProofPostproccess::~ProofPostproccess() {}
1140
1141 void ProofPostproccess::process(std::shared_ptr<ProofNode> pf)
1142 {
1143 // Initialize the callback, which computes necessary static information about
1144 // how to process, including how to process assumptions in pf.
1145 d_cb.initializeUpdate();
1146 // now, process
1147 d_updater.process(pf);
1148 // take stats and check pedantic
1149 d_finalCb.initializeUpdate();
1150 d_finalizer.process(pf);
1151
1152 std::stringstream serr;
1153 bool wasPedanticFailure = d_finalCb.wasPedanticFailure(serr);
1154 if (wasPedanticFailure)
1155 {
1156 AlwaysAssert(!wasPedanticFailure)
1157 << "ProofPostproccess::process: pedantic failure:" << std::endl
1158 << serr.str();
1159 }
1160 }
1161
1162 void ProofPostproccess::setEliminateRule(PfRule rule)
1163 {
1164 d_cb.setEliminateRule(rule);
1165 }
1166
1167 } // namespace smt
1168 } // namespace CVC4