From 109e7e43efdeb557ff17880da83da438db35eb3e Mon Sep 17 00:00:00 2001 From: Haniel Barbosa Date: Fri, 22 Jan 2021 13:15:43 -0300 Subject: [PATCH] [proof-new] Expanding MACRO_RESOLUTION in post-processing (#5755) Breaks down resolution, factoring and reordering. The hardest part of this process is getting rid of the so-called "crowding literals", i.e., duplicate literals introduced during the series of resolutions and removed implicitly by the SAT solver. A naive removal via addition of premises to the chain resolution can lead to exponential behavior, so instead the removal is done by breaking the resolution and applying a factoring step midway through. This guarantees non-exponential behavior. --- src/smt/proof_manager.cpp | 1 + src/smt/proof_post_processor.cpp | 345 +++++++++++++++++++++++++++++++ src/smt/proof_post_processor.h | 78 +++++++ 3 files changed, 424 insertions(+) diff --git a/src/smt/proof_manager.cpp b/src/smt/proof_manager.cpp index d82e22736..b8f68af88 100644 --- a/src/smt/proof_manager.cpp +++ b/src/smt/proof_manager.cpp @@ -38,6 +38,7 @@ PfManager::PfManager(context::UserContext* u, SmtEngine* smte) d_pfpp->setEliminateRule(PfRule::MACRO_SR_PRED_INTRO); d_pfpp->setEliminateRule(PfRule::MACRO_SR_PRED_ELIM); d_pfpp->setEliminateRule(PfRule::MACRO_SR_PRED_TRANSFORM); + d_pfpp->setEliminateRule(PfRule::MACRO_RESOLUTION); if (options::proofGranularityMode() != options::ProofGranularityMode::REWRITE) { diff --git a/src/smt/proof_post_processor.cpp b/src/smt/proof_post_processor.cpp index 049eb02c0..a620a4d22 100644 --- a/src/smt/proof_post_processor.cpp +++ b/src/smt/proof_post_processor.cpp @@ -130,6 +130,246 @@ bool ProofPostprocessCallback::updateInternal(Node res, return update(res, id, children, args, cdp, continueUpdate); } +Node ProofPostprocessCallback::eliminateCrowdingLits( + const std::vector& clauseLits, + const std::vector& targetClauseLits, + const std::vector& children, + const std::vector& args, + CDProof* cdp) +{ + NodeManager* 
nm = NodeManager::currentNM(); + Node trueNode = nm->mkConst(true); + // get crowding lits and the position of the last clause that includes + // them. The factoring step must be added after the last inclusion and before + // its elimination. + std::unordered_set crowding; + std::vector> lastInclusion; + // positions of eliminators of crowding literals, which are the positions of + // the clauses that eliminate crowding literals *after* their last inclusion + std::vector eliminators; + for (size_t i = 0, size = clauseLits.size(); i < size; ++i) + { + if (!crowding.count(clauseLits[i]) + && std::find( + targetClauseLits.begin(), targetClauseLits.end(), clauseLits[i]) + == targetClauseLits.end()) + { + Node crowdLit = clauseLits[i]; + crowding.insert(crowdLit); + // found crowding lit, now get its last inclusion position, which is the + // position of the last resolution link that introduces the crowding + // literal. Note that this position has to be *before* the last link, as a + // link *after* the last inclusion must eliminate the crowding literal. + size_t j; + for (j = children.size() - 1; j > 0; --j) + { + // notice that only non-unit clauses may be introducing the crowding + // literal, so we don't need to differentiate unit from non-unit + if (children[j - 1].getKind() != kind::OR) + { + continue; + } + if (std::find(children[j - 1].begin(), children[j - 1].end(), crowdLit) + != children[j - 1].end()) + { + break; + } + } + Assert(j > 0); + lastInclusion.emplace_back(crowdLit, j - 1); + Trace("smt-proof-pp-debug2") << "crowding lit " << crowdLit << "\n"; + Trace("smt-proof-pp-debug2") << "last inc " << j - 1 << "\n"; + // get elimination position, starting from the following link as the last + // inclusion one. The result is the last (in the chain, but first from + // this point on) resolution link that eliminates the crowding literal. A + // literal l is eliminated by a link if it contains a literal l' with + // opposite polarity to l. 
+ for (; j < children.size(); ++j) + { + bool posFirst = args[(2 * j) - 1] == trueNode; + Node pivot = args[(2 * j)]; + Trace("smt-proof-pp-debug2") + << "\tcheck w/ args " << posFirst << " / " << pivot << "\n"; + // To eliminate the crowding literal (crowdLit), the clause must contain + // it with opposite polarity. There are three successful cases, + // according to the pivot and its sign + // + // - crowdLit is the same as the pivot and posFirst is true, which means + // that the clause contains its negation and eliminates it + // + // - crowdLit is the negation of the pivot and posFirst is false, so the + // clause contains the node whose negation is crowdLit. Note that this + // case may either be crowdLit.notNode() == pivot or crowdLit == + // pivot.notNode(). + if ((crowdLit == pivot && posFirst) + || (crowdLit.notNode() == pivot && !posFirst) + || (pivot.notNode() == crowdLit && !posFirst)) + { + Trace("smt-proof-pp-debug2") << "\t\tfound it!\n"; + eliminators.push_back(j); + break; + } + } + Assert(j < children.size()); + } + } + Assert(!lastInclusion.empty()); + // order map so that we process crowding literals in the order of the clauses + // that last introduce them + auto cmp = [](std::pair& a, std::pair& b) { + return a.second < b.second; + }; + std::sort(lastInclusion.begin(), lastInclusion.end(), cmp); + // order eliminators + std::sort(eliminators.begin(), eliminators.end()); + if (Trace.isOn("smt-proof-pp-debug")) + { + Trace("smt-proof-pp-debug") << "crowding lits last inclusion:\n"; + for (const auto& pair : lastInclusion) + { + Trace("smt-proof-pp-debug") + << "\t- [" << pair.second << "] : " << pair.first << "\n"; + } + Trace("smt-proof-pp-debug") << "eliminators:"; + for (size_t elim : eliminators) + { + Trace("smt-proof-pp-debug") << " " << elim; + } + Trace("smt-proof-pp-debug") << "\n"; + } + // TODO (cvc4-wishues/issues/77): implement also simpler version and compare + // + // We now start to break the chain, one step at a time. 
Naively this breaking + // down would be one resolution/factoring to each crowding literal, but we can + // merge some of the cases. Effectively we do the following: + // + // + // lastClause children[start] ... children[end] + // ---------------------------------------------- CHAIN_RES + // C + // ----------- FACTORING + // lastClause' children[start'] ... children[end'] + // -------------------------------------------------------------- CHAIN_RES + // ... + // + // where + // lastClause_0 = children[0] + // start_0 = 1 + // end_0 = eliminators[0] - 1 + // start_i+1 = end_i + 1 and end_i+1 = nextGuardedElimPos - 1 + // + // The important point is how end_i+1 is computed. It is based on what we call + // the "nextGuardedElimPos", i.e., the next elimination position that requires + // removal of duplicates. The intuition is that a factoring step may eliminate + // the duplicates of crowding literals l1 and l2. If the last inclusion of l2 + // is before the elimination of l1, then we can go ahead and also perform the + // elimination of l2 without another factoring. However if another literal l3 + // has its last inclusion after the elimination of l2, then the elimination of + // l3 is the next guarded elimination. + // + // To do the above computation then we determine, after a resolution/factoring + // step, the first crowded literal to have its last inclusion after "end". The + // first elimination position to be bigger than the position of that crowded + // literal is the next guarded elimination position. 
+ size_t lastElim = 0; + Node lastClause = children[0]; + std::vector childrenRes; + std::vector childrenResArgs; + Node resPlaceHolder; + size_t nextGuardedElimPos = eliminators[0]; + do + { + size_t start = lastElim + 1; + size_t end = nextGuardedElimPos - 1; + Trace("smt-proof-pp-debug2") + << "res with:\n\tlastClause: " << lastClause << "\n\tstart: " << start + << "\n\tend: " << end << "\n"; + childrenRes.push_back(lastClause); + // note that the interval of insert is exclusive in the end, so we add 1 + childrenRes.insert(childrenRes.end(), + children.begin() + start, + children.begin() + end + 1); + childrenResArgs.insert(childrenResArgs.end(), + args.begin() + (2 * start) - 1, + args.begin() + (2 * end) + 1); + Trace("smt-proof-pp-debug2") << "res children: " << childrenRes << "\n"; + Trace("smt-proof-pp-debug2") << "res args: " << childrenResArgs << "\n"; + resPlaceHolder = d_pnm->getChecker()->checkDebug(PfRule::CHAIN_RESOLUTION, + childrenRes, + childrenResArgs, + Node::null(), + ""); + Trace("smt-proof-pp-debug2") + << "resPlaceHorder: " << resPlaceHolder << "\n"; + cdp->addStep( + resPlaceHolder, PfRule::CHAIN_RESOLUTION, childrenRes, childrenResArgs); + // I need to add factoring if end < children.size(). 
Otherwise, this is + // to be handled by the caller + if (end < children.size() - 1) + { + lastClause = d_pnm->getChecker()->checkDebug( + PfRule::FACTORING, {resPlaceHolder}, {}, Node::null(), ""); + if (!lastClause.isNull()) + { + cdp->addStep(lastClause, PfRule::FACTORING, {resPlaceHolder}, {}); + } + else + { + lastClause = resPlaceHolder; + } + Trace("smt-proof-pp-debug2") << "lastClause: " << lastClause << "\n"; + } + else + { + lastClause = resPlaceHolder; + break; + } + // update for next round + childrenRes.clear(); + childrenResArgs.clear(); + lastElim = end; + + // find the position of the last inclusion of the next crowded literal + size_t nextCrowdedInclusionPos = lastInclusion.size(); + for (size_t i = 0, size = lastInclusion.size(); i < size; ++i) + { + if (lastInclusion[i].second > lastElim) + { + nextCrowdedInclusionPos = i; + break; + } + } + Trace("smt-proof-pp-debug2") + << "nextCrowdedInclusion/Pos: " + << lastInclusion[nextCrowdedInclusionPos].second << "/" + << nextCrowdedInclusionPos << "\n"; + // if there is none, then the remaining literals will be used in the next + // round + if (nextCrowdedInclusionPos == lastInclusion.size()) + { + nextGuardedElimPos = children.size(); + } + else + { + nextGuardedElimPos = children.size(); + for (size_t i = 0, size = eliminators.size(); i < size; ++i) + { + // nextGuardedElimPos is the smallest element of + // eliminators bigger than the next crowded literal's last inclusion + if (eliminators[i] > lastInclusion[nextCrowdedInclusionPos].second) + { + nextGuardedElimPos = eliminators[i]; + break; + } + } + Assert(nextGuardedElimPos < children.size()); + } + Trace("smt-proof-pp-debug2") + << "nextGuardedElimPos: " << nextGuardedElimPos << "\n"; + } while (true); + return lastClause; +} + Node ProofPostprocessCallback::expandMacros(PfRule id, const std::vector& children, const std::vector& args, @@ -375,6 +615,111 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, cdp->addStep(args[0], PfRule::EQ_RESOLVE, 
{children[0], eq}, {}); + return args[0]; + } + else if (id == PfRule::MACRO_RESOLUTION) + { + // first generate the naive chain_resolution + std::vector chainResArgs{args.begin() + 1, args.end()}; + Node chainConclusion = d_pnm->getChecker()->checkDebug( + PfRule::CHAIN_RESOLUTION, children, chainResArgs, Node::null(), ""); + Trace("smt-proof-pp-debug") << "Original conclusion: " << args[0] << "\n"; + Trace("smt-proof-pp-debug") + << "chainRes conclusion: " << chainConclusion << "\n"; + // There are four cases: + // - if the conclusion is the same, just replace + // - if they have the same literals but in different quantity, add a + // FACTORING step + // - if the order is not the same, add a REORDERING step + // - if there are literals in chainConclusion that are not in the original + // conclusion, we need to transform the MACRO_RESOLUTION into a series of + // CHAIN_RESOLUTION + FACTORING steps, so that we explicitly eliminate all + // these "crowding" literals. We do this via FACTORING so we avoid adding + // an exponential number of premises, which would happen if we just + // repeated in the premises the clauses needed for eliminating crowding + // literals, which could themselves add crowding literals. + if (chainConclusion == args[0]) + { + cdp->addStep( + chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs); + return chainConclusion; + } + NodeManager* nm = NodeManager::currentNM(); + // If we got here, then chainConclusion is NECESSARILY an OR node + Assert(chainConclusion.getKind() == kind::OR); + // get the literals in the chain conclusion + std::vector chainConclusionLits{chainConclusion.begin(), + chainConclusion.end()}; + std::set chainConclusionLitsSet{chainConclusion.begin(), + chainConclusion.end()}; + // is args[0] a unit clause? If it's not an OR node, then yes. 
Otherwise, + // it's only a unit if it occurs in chainConclusionLitsSet + std::vector conclusionLits; + // whether conclusion is unit + if (chainConclusionLitsSet.count(args[0])) + { + conclusionLits.push_back(args[0]); + } + else + { + Assert(args[0].getKind() == kind::OR); + conclusionLits.insert( + conclusionLits.end(), args[0].begin(), args[0].end()); + } + std::set conclusionLitsSet{conclusionLits.begin(), + conclusionLits.end()}; + // If the sets are different, there are "crowding" literals, i.e. literals + // that were removed by implicit multi-usage of premises in the resolution + // chain. + if (chainConclusionLitsSet != conclusionLitsSet) + { + chainConclusion = eliminateCrowdingLits( + chainConclusionLits, conclusionLits, children, args, cdp); + // update vector of lits. Note that the set is no longer used, so we don't + // need to update it + chainConclusionLits.clear(); + chainConclusionLits.insert(chainConclusionLits.end(), + chainConclusion.begin(), + chainConclusion.end()); + } + else + { + cdp->addStep( + chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs); + } + // Placeholder for running conclusion + Node n = chainConclusion; + // factoring + if (chainConclusionLits.size() != conclusionLits.size()) + { + // We build it rather than taking conclusionLits because the order may be + // different + std::vector factoredLits; + std::unordered_set clauseSet; + for (size_t i = 0, size = chainConclusionLits.size(); i < size; ++i) + { + if (clauseSet.count(chainConclusionLits[i])) + { + continue; + } + factoredLits.push_back(n[i]); + clauseSet.insert(n[i]); + } + Node factored = factoredLits.empty() + ? nm->mkConst(false) + : factoredLits.size() == 1 + ? 
factoredLits[0] + : nm->mkNode(kind::OR, factoredLits); + cdp->addStep(factored, PfRule::FACTORING, {n}, {}); + n = factored; + } + // either same node or n as a clause + Assert(n == args[0] || n.getKind() == kind::OR); + // reordering + if (n != args[0]) + { + cdp->addStep(args[0], PfRule::REORDERING, {n}, {args[0]}); + } + return args[0]; + } else if (id == PfRule::SUBS) { NodeManager* nm = NodeManager::currentNM(); diff --git a/src/smt/proof_post_processor.h b/src/smt/proof_post_processor.h index 885608b38..de74d4869 100644 --- a/src/smt/proof_post_processor.h +++ b/src/smt/proof_post_processor.h @@ -149,6 +149,84 @@ class ProofPostprocessCallback : public ProofNodeUpdaterCallback bool addToTransChildren(Node eq, std::vector& tchildren, bool isSymm = false); + + /** + * When given children and args lead to different sets of literals in a + * conclusion depending on whether macro resolution or chain resolution is + * applied, the literals that appear in the chain resolution result, but not + * in the macro resolution result, from now on "crowding literals", are + * literals removed implicitly by macro resolution. For example + * + * l0 v l0 v l0 v l1 v l2 ~l0 v l1 ~l1 + * (1) ----------------------------------------- MACRO_RES + * l2 + * + * but + * + * l0 v l0 v l0 v l1 v l2 ~l0 v l1 ~l1 + * (2) ---------------------------------------- CHAIN_RES + * l0 v l0 v l1 v l2 + * + * where l0 and l1 are crowding literals in the second proof. + * + * There are two views for how MACRO_RES implicitly removes the crowding + * literal, i.e., how MACRO_RES can be expanded into CHAIN_RES so that + * crowding literals are removed. The first is that (1) becomes + * + * l0 v l0 v l0 v l1 v l2 ~l0 v l1 ~l0 v l1 ~l0 v l1 ~l1 ~l1 ~l1 ~l1 + * ---------------------------------------------------------------- CHAIN_RES + * l2 + * + * via the repetition of the premise responsible for removing more than one + * occurrence of the crowding literal. 
The issue however is that this + * expansion is exponential. Note that (2) has two occurrences of l0 and one + * of l1 as crowding literals. However, by repeating ~l0 v l1 two times to + * remove l0, the clause ~l1, which would originally need to be repeated only + * one time, now has to be repeated two extra times on top of that one. With + * multiple crowding literals and their elimination depending on premises that + * themselves add crowding literals one can easily end up with resolution + * chains going from dozens to thousands of premises. Such examples do occur + * in practice, even in our regressions. + * + * The second way of expanding MACRO_RES, which avoids this exponential + * behavior, is so that (1) becomes + * + * l0 v l0 v l0 v l1 v l2 + * (4) ---------------------- FACTORING + * l0 v l1 v l2 ~l0 v l1 + * ------------------------------------------- CHAIN_RES + * l1 v l1 v l2 + * ------------- FACTORING + * l1 v l2 ~l1 + * ------------------------------ CHAIN_RES + * l2 + * + * This method first determines what are the crowding literals by checking + * what literals occur in clauseLits that do not occur in targetClauseLits + * (the latter contains the literals from the original MACRO_RES conclusion + * while the former the literals from a direct application of CHAIN_RES). Then + * it builds a proof such as (4) and adds the steps to cdp. The final + * conclusion is returned. + * + * Note that in the example the CHAIN_RES steps introduced had only two + * premises, and could thus be replaced by a RESOLUTION step, but since in + * general there can be more than two premises we always use CHAIN_RES. 
+ * + * @param clauseLits literals in the conclusion of a CHAIN_RESOLUTION step + * with children and args[1:] + * @param targetClauseLits literals in the conclusion of a MACRO_RESOLUTION step + * with children and args + * @param children a list of clauses + * @param args a list of arguments to a MACRO_RESOLUTION step + * @param cdp a CDProof + * @return The resulting node of transforming MACRO_RESOLUTION into + * CHAIN_RESOLUTION according to the above idea. + */ + Node eliminateCrowdingLits(const std::vector& clauseLits, + const std::vector& targetClauseLits, + const std::vector& children, + const std::vector& args, + CDProof* cdp); }; /** Final callback class, for stats and pedantic checking */ -- 2.30.2