From: Aina Niemetz Date: Thu, 2 Sep 2021 03:44:06 +0000 (-0700) Subject: pp: Derive PreprocessingPass from EnvObj. (#7112) X-Git-Tag: cvc5-1.0.0~1296 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=eb472264003f202b842e4d25355a1c3f01467750;p=cvc5.git pp: Derive PreprocessingPass from EnvObj. (#7112) --- diff --git a/src/preprocessing/passes/ite_simp.cpp b/src/preprocessing/passes/ite_simp.cpp index 5b13be619..54e2b657e 100644 --- a/src/preprocessing/passes/ite_simp.cpp +++ b/src/preprocessing/passes/ite_simp.cpp @@ -24,7 +24,6 @@ #include "smt/smt_statistics_registry.h" #include "smt_util/nary_builder.h" #include "theory/arith/arith_ite_utils.h" -#include "theory/rewriter.h" #include "theory/theory_engine.h" using namespace std; @@ -39,33 +38,6 @@ namespace passes { namespace { -Node simpITE(util::ITEUtilities* ite_utils, TNode assertion) -{ - if (!ite_utils->containsTermITE(assertion)) - { - return assertion; - } - else - { - Node result = ite_utils->simpITE(assertion); - Node res_rewritten = Rewriter::rewrite(result); - - if (options::simplifyWithCareEnabled()) - { - Chat() << "starting simplifyWithCare()" << endl; - Node postSimpWithCare = ite_utils->simplifyWithCare(res_rewritten); - Chat() << "ending simplifyWithCare()" - << " post simplifyWithCare()" << postSimpWithCare.getId() << endl; - result = Rewriter::rewrite(postSimpWithCare); - } - else - { - result = res_rewritten; - } - return result; - } -} - /** * Ensures the assertions asserted after index 'before' now effectively come * before 'real_assertions_end'. @@ -114,6 +86,33 @@ ITESimp::Statistics::Statistics() { } +Node ITESimp::simpITE(util::ITEUtilities* ite_utils, TNode assertion) +{ + if (!ite_utils->containsTermITE(assertion)) + { + return assertion; + } + else + { + Node result = ite_utils->simpITE(assertion); + Node res_rewritten = rewrite(result); + + if (options::simplifyWithCareEnabled()) + { + Chat() << "starting simplifyWithCare()" << endl; + Node postSimpWithCare = ite_utils->simplifyWithCare(res_rewritten); + Chat() << "ending simplifyWithCare()" + << " post simplifyWithCare()" << postSimpWithCare.getId() << endl; + result = rewrite(postSimpWithCare); + } + else + { + result = res_rewritten; + } + return result; + } +} + bool ITESimp::doneSimpITE(AssertionPipeline* assertionsToPreprocess) { bool result = true; @@ -136,7 +135,7 @@ bool ITESimp::doneSimpITE(AssertionPipeline* assertionsToPreprocess) Chat() << "....node manager contains " << nm->poolSize() << " nodes before cleanup" << endl; d_iteUtilities.clear(); - d_preprocContext->getRewriter()->clearCaches(); + d_env.getRewriter()->clearCaches(); nm->reclaimZombiesUntil(options::zombieHuntThreshold()); Chat() << "....node manager contains " << nm->poolSize() << " nodes after cleanup" << endl; @@ -145,8 +144,7 @@ bool ITESimp::doneSimpITE(AssertionPipeline* assertionsToPreprocess) } // Do theory specific preprocessing passes - TheoryEngine* theory_engine = d_preprocContext->getTheoryEngine(); - if (theory_engine->getLogicInfo().isTheoryEnabled(theory::THEORY_ARITH) + if (d_env.getLogicInfo().isTheoryEnabled(theory::THEORY_ARITH) && !options::incrementalSolving()) { if (!simpDidALotOfWork) @@ -170,7 +168,7 @@ bool ITESimp::doneSimpITE(AssertionPipeline* assertionsToPreprocess) { Node more = aiteu.reduceConstantIteByGCD(res); Debug("arith::ite::red") << " gcd->" << more << endl; - Node morer = Rewriter::rewrite(more); + Node morer = rewrite(more); assertionsToPreprocess->replace(i, morer); } } @@ -187,7 +185,7 @@ bool ITESimp::doneSimpITE(AssertionPipeline* assertionsToPreprocess) for (size_t i = 0, N = assertionsToPreprocess->size(); i < N; ++i) { Node curr = (*assertionsToPreprocess)[i]; - Node next = Rewriter::rewrite(aiteu.applySubstitutions(curr)); + Node next = rewrite(aiteu.applySubstitutions(curr)); Node res = aiteu.reduceVariablesInItes(next); Debug("arith::ite::red") << "@ " << i << " ... " << next << endl << " ->" << res << endl; @@ -204,13 +202,13 @@ bool ITESimp::doneSimpITE(AssertionPipeline* assertionsToPreprocess) ++i) { Node curr = (*assertionsToPreprocess)[i]; - Node next = Rewriter::rewrite(aiteu.applySubstitutions(curr)); + Node next = rewrite(aiteu.applySubstitutions(curr)); Node res = aiteu.reduceVariablesInItes(next); Debug("arith::ite::red") << "@ " << i << " ... " << next << endl << " ->" << res << endl; Node more = aiteu.reduceConstantIteByGCD(res); Debug("arith::ite::red") << " gcd->" << more << endl; - Node morer = Rewriter::rewrite(more); + Node morer = rewrite(more); assertionsToPreprocess->replace(i, morer); } } @@ -223,7 +221,7 @@ bool ITESimp::doneSimpITE(AssertionPipeline* assertionsToPreprocess) /* -------------------------------------------------------------------------- */ ITESimp::ITESimp(PreprocessingPassContext* preprocContext) - : PreprocessingPass(preprocContext, "ite-simp") + : PreprocessingPass(preprocContext, "ite-simp"), d_iteUtilities(d_env) { } diff --git a/src/preprocessing/passes/ite_simp.h b/src/preprocessing/passes/ite_simp.h index b342301f5..ca8547d46 100644 --- a/src/preprocessing/passes/ite_simp.h +++ b/src/preprocessing/passes/ite_simp.h @@ -42,6 +42,7 @@ class ITESimp : public PreprocessingPass Statistics(); }; + Node simpITE(util::ITEUtilities* ite_utils, TNode assertion); bool doneSimpITE(AssertionPipeline *assertionsToPreprocesss); /** A collection of ite preprocessing passes. */ diff --git a/src/preprocessing/preprocessing_pass.cpp b/src/preprocessing/preprocessing_pass.cpp index 3b2a088a2..f38f1e823 100644 --- a/src/preprocessing/preprocessing_pass.cpp +++ b/src/preprocessing/preprocessing_pass.cpp @@ -58,10 +58,11 @@ void PreprocessingPass::dumpAssertions(const char* key, PreprocessingPass::PreprocessingPass(PreprocessingPassContext* preprocContext, const std::string& name) - : d_name(name), + : EnvObj(preprocContext->getEnv()), + d_preprocContext(preprocContext), + d_name(name), d_timer(smtStatisticsRegistry().registerTimer("preprocessing::" + name)) { - d_preprocContext = preprocContext; } PreprocessingPass::~PreprocessingPass() { diff --git a/src/preprocessing/preprocessing_pass.h b/src/preprocessing/preprocessing_pass.h index 14824ee4b..e1c3a5100 100644 --- a/src/preprocessing/preprocessing_pass.h +++ b/src/preprocessing/preprocessing_pass.h @@ -34,6 +34,7 @@ #include +#include "smt/env_obj.h" #include "util/statistics_stats.h" namespace cvc5 { @@ -48,7 +49,8 @@ class PreprocessingPassContext; */ enum PreprocessingPassResult { CONFLICT, NO_CONFLICT }; -class PreprocessingPass { +class PreprocessingPass : public EnvObj +{ public: /* Preprocesses a list of assertions assertionsToPreprocess */ PreprocessingPassResult apply(AssertionPipeline* assertionsToPreprocess); diff --git a/src/preprocessing/preprocessing_pass_context.h b/src/preprocessing/preprocessing_pass_context.h index a2407b953..65804b2b5 100644 --- a/src/preprocessing/preprocessing_pass_context.h +++ b/src/preprocessing/preprocessing_pass_context.h @@ -58,7 +58,8 @@ class PreprocessingPassContext /** Get the associated SmtEngine. */ SmtEngine* getSmt() const { return d_smt; } - + /** Get the associated Environment. */ + Env& getEnv() { return d_env; } /** Get the associated TheoryEngine. */ TheoryEngine* getTheoryEngine() const; /** Get the associated Propengine. */ diff --git a/src/preprocessing/util/ite_utilities.cpp b/src/preprocessing/util/ite_utilities.cpp index 13910f534..9f59ad397 100644 --- a/src/preprocessing/util/ite_utilities.cpp +++ b/src/preprocessing/util/ite_utilities.cpp @@ -26,7 +26,6 @@ #include "preprocessing/assertion_pipeline.h" #include "preprocessing/passes/rewrite.h" #include "smt/smt_statistics_registry.h" -#include "theory/rewriter.h" #include "theory/theory.h" #include "util/rational.h" @@ -87,8 +86,9 @@ struct CTIVStackElement } // namespace ite -ITEUtilities::ITEUtilities() - : d_containsVisitor(new ContainsTermITEVisitor()), +ITEUtilities::ITEUtilities(Env& env) + : EnvObj(env), + d_containsVisitor(new ContainsTermITEVisitor()), d_compressor(NULL), d_simplifier(NULL), d_careSimp(NULL) @@ -116,7 +116,7 @@ Node ITEUtilities::simpITE(TNode assertion) { if (d_simplifier == NULL) { - d_simplifier = new ITESimplifier(d_containsVisitor.get()); + d_simplifier = new ITESimplifier(d_env, d_containsVisitor.get()); } return d_simplifier->simpITE(assertion); } @@ -138,7 +138,7 @@ bool ITEUtilities::compress(AssertionPipeline* assertionsToPreprocess) { if (d_compressor == NULL) { - d_compressor = new ITECompressor(d_containsVisitor.get()); + d_compressor = new ITECompressor(d_env, d_containsVisitor.get()); } return d_compressor->compress(assertionsToPreprocess); } @@ -288,8 +288,11 @@ void IncomingArcCounter::computeReachability( void IncomingArcCounter::clear() { d_reachCount.clear(); } /** ITECompressor. */ -ITECompressor::ITECompressor(ContainsTermITEVisitor* contains) - : d_contains(contains), d_assertions(NULL), d_incoming(true, true) +ITECompressor::ITECompressor(Env& env, ContainsTermITEVisitor* contains) + : EnvObj(env), + d_contains(contains), + d_assertions(NULL), + d_incoming(true, true) { Assert(d_contains != NULL); @@ -316,7 +319,7 @@ ITECompressor::Statistics::Statistics() Node ITECompressor::push_back_boolean(Node original, Node compressed) { - Node rewritten = theory::Rewriter::rewrite(compressed); + Node rewritten = rewrite(compressed); // There is a bug if the rewritter takes a pure boolean expression // and changes its theory if (rewritten.isConst()) @@ -538,7 +541,7 @@ bool ITECompressor::compress(AssertionPipeline* assertionsToPreprocess) { Node assertion = assertions[i]; Node compressed = compressBoolean(assertion); - Node rewritten = theory::Rewriter::rewrite(compressed); + Node rewritten = rewrite(compressed); // replace assertionsToPreprocess->replace(i, rewritten); Assert(!d_contains->containsTermITE(rewritten)); @@ -637,8 +640,9 @@ uint32_t TermITEHeightCounter::termITEHeight(TNode e) return returnValue; } -ITESimplifier::ITESimplifier(ContainsTermITEVisitor* contains) - : d_containsVisitor(contains), +ITESimplifier::ITESimplifier(Env& env, ContainsTermITEVisitor* contains) + : EnvObj(env), + d_containsVisitor(contains), d_termITEHeight(), d_constantLeaves(), d_allocatedConstantLeaves(), @@ -1325,15 +1329,14 @@ Node ITESimplifier::simpConstants(TNode simpContext, } // Mark the substitution and continue Node result = builder; - result = theory::Rewriter::rewrite(result); + result = rewrite(result); d_simpConstCache[pair(simpContext, iteNode)] = result; return result; } if (!containsTermITE(iteNode)) { - Node n = - theory::Rewriter::rewrite(simpContext.substitute(simpVar, iteNode)); + Node n = rewrite(simpContext.substitute(simpVar, iteNode)); d_simpConstCache[pair(simpContext, iteNode)] = n; return n; } @@ -1459,7 +1462,7 @@ Node ITESimplifier::simpITEAtom(TNode atom) Debug("ite::atom") << " finished " << instance << endl; if (!attempt.isNull()) { - Node rewritten = theory::Rewriter::rewrite(attempt); + Node rewritten = rewrite(attempt); Debug("ite::print-success") << instance << " " << "rewriting " << countReachable(rewritten, kind::ITE) << " from " @@ -1485,7 +1488,7 @@ Node ITESimplifier::simpITEAtom(TNode atom) << "how about?" << atom << endl; Debug("ite::simpite") << instance << " " << "\t" << simpContext << endl; - return theory::Rewriter::rewrite(simpContext); + return rewrite(simpContext); } Node n = simpConstants(simpContext, iteNode, simpVar); if (!n.isNull()) @@ -1588,7 +1591,7 @@ Node ITESimplifier::simpITE(TNode assertion) // //cout << instance << " " << result << current << endl; // } - result = theory::Rewriter::rewrite(result); + result = rewrite(result); d_simpITECache[current] = result; ++(d_statistics.d_simpITEVisits); toVisit.pop_back(); diff --git a/src/preprocessing/util/ite_utilities.h b/src/preprocessing/util/ite_utilities.h index 3ab655767..48408e46e 100644 --- a/src/preprocessing/util/ite_utilities.h +++ b/src/preprocessing/util/ite_utilities.h @@ -28,10 +28,12 @@ #include #include "expr/node.h" +#include "smt/env_obj.h" #include "util/hash.h" #include "util/statistics_stats.h" namespace cvc5 { + namespace preprocessing { class AssertionPipeline; @@ -65,10 +67,10 @@ class ContainsTermITEVisitor NodeBoolMap d_cache; }; -class ITEUtilities +class ITEUtilities : public EnvObj { public: - ITEUtilities(); + ITEUtilities(Env& env); ~ITEUtilities(); Node simpITE(TNode assertion); @@ -162,10 +164,10 @@ class TermITEHeightCounter * A routine designed to undo the potentially large blow up * due to expansion caused by the ite simplifier. */ -class ITECompressor +class ITECompressor : public EnvObj { public: - ITECompressor(ContainsTermITEVisitor* contains); + ITECompressor(Env& env, ContainsTermITEVisitor* contains); ~ITECompressor(); /* returns false if an assertion is discovered to be equal to false. */ @@ -175,14 +177,13 @@ class ITECompressor void garbageCollect(); private: - Node d_true; /* Copy of true. */ - Node d_false; /* Copy of false. */ - ContainsTermITEVisitor* d_contains; - AssertionPipeline* d_assertions; - IncomingArcCounter d_incoming; - - typedef std::unordered_map NodeMap; - NodeMap d_compressed; + class Statistics + { + public: + IntStat d_compressCalls; + IntStat d_skolemsAdded; + Statistics(); + }; void reset(); @@ -192,20 +193,23 @@ class ITECompressor Node compressTerm(Node toCompress); Node compressBoolean(Node toCompress); - class Statistics - { - public: - IntStat d_compressCalls; - IntStat d_skolemsAdded; - Statistics(); - }; + Node d_true; /* Copy of true. */ + Node d_false; /* Copy of false. */ + + ContainsTermITEVisitor* d_contains; + AssertionPipeline* d_assertions; + IncomingArcCounter d_incoming; + + typedef std::unordered_map NodeMap; + NodeMap d_compressed; + Statistics d_statistics; }; /* class ITECompressor */ -class ITESimplifier +class ITESimplifier : public EnvObj { public: - ITESimplifier(ContainsTermITEVisitor* d_containsVisitor); + ITESimplifier(Env& env, ContainsTermITEVisitor* d_containsVisitor); ~ITESimplifier(); Node simpITE(TNode assertion); @@ -214,27 +218,39 @@ class ITESimplifier void clearSimpITECaches(); private: - Node d_true; - Node d_false; + using NodeVec = std::vector; + using ConstantLeavesMap = std::unordered_map; + using NodePair = std::pair; + using NodePairHashFunction = + PairHashFunction, std::hash>; + using NodePairMap = std::unordered_map; + + class Statistics + { + public: + IntStat d_maxNonConstantsFolded; + IntStat d_unexpected; + IntStat d_unsimplified; + IntStat d_exactMatchFold; + IntStat d_binaryPredFold; + IntStat d_specialEqualityFolds; + IntStat d_simpITEVisits; + + HistogramStat d_inSmaller; + + Statistics(); + }; - ContainsTermITEVisitor* d_containsVisitor; inline bool containsTermITE(TNode n) { return d_containsVisitor->containsTermITE(n); } - TermITEHeightCounter d_termITEHeight; + inline uint32_t termITEHeight(TNode e) { return d_termITEHeight.termITEHeight(e); } - // ConstantIte is a small inductive sublanguage: - // constant - // or termITE(cnd, ConstantIte, ConstantIte) - typedef std::vector NodeVec; - typedef std::unordered_map ConstantLeavesMap; - ConstantLeavesMap d_constantLeaves; - // d_constantLeaves satisfies the following invariants: // not containsTermITE(x) then !isKey(x) // containsTermITE(x): @@ -248,9 +264,6 @@ class ITESimplifier * returns a sorted NodeVec of the leaves. */ NodeVec* computeConstantLeaves(TNode ite); - // Lists all of the vectors in d_constantLeaves for fast deletion. - std::vector d_allocatedConstantLeaves; - /* transforms */ Node transformAtom(TNode atom); Node attemptConstantRemoval(TNode atom); @@ -264,51 +277,50 @@ class ITESimplifier // Given ConstantIte tree cite and a constant c, // return a boolean expression equivalent to (= lcite c) Node constantIteEqualsConstant(TNode cite, TNode c); + + Node replaceOver(Node n, Node replaceWith, Node simpVar); + Node replaceOverTermIte(Node term, Node simpAtom, Node simpVar); + + bool leavesAreConst(TNode e, theory::TheoryId tid); + bool leavesAreConst(TNode e); + + Node simpConstants(TNode simpContext, TNode iteNode, TNode simpVar); + + Node createSimpContext(TNode c, Node& iteNode, Node& simpVar); + + Node d_true; + Node d_false; + + ContainsTermITEVisitor* d_containsVisitor; + + TermITEHeightCounter d_termITEHeight; + + // ConstantIte is a small inductive sublanguage: + // constant + // or termITE(cnd, ConstantIte, ConstantIte) + ConstantLeavesMap d_constantLeaves; + + // Lists all of the vectors in d_constantLeaves for fast deletion. + std::vector d_allocatedConstantLeaves; + uint32_t d_citeEqConstApplications; - typedef std::pair NodePair; - using NodePairHashFunction = - PairHashFunction, std::hash>; - typedef std::unordered_map NodePairMap; NodePairMap d_constantIteEqualsConstantCache; NodePairMap d_replaceOverCache; NodePairMap d_replaceOverTermIteCache; - Node replaceOver(Node n, Node replaceWith, Node simpVar); - Node replaceOverTermIte(Node term, Node simpAtom, Node simpVar); std::unordered_map d_leavesConstCache; - bool leavesAreConst(TNode e, theory::TheoryId tid); - bool leavesAreConst(TNode e); NodePairMap d_simpConstCache; - Node simpConstants(TNode simpContext, TNode iteNode, TNode simpVar); std::unordered_map d_simpVars; Node getSimpVar(TypeNode t); typedef std::unordered_map NodeMap; NodeMap d_simpContextCache; - Node createSimpContext(TNode c, Node& iteNode, Node& simpVar); NodeMap d_simpITECache; Node simpITEAtom(TNode atom); - private: - class Statistics - { - public: - IntStat d_maxNonConstantsFolded; - IntStat d_unexpected; - IntStat d_unsimplified; - IntStat d_exactMatchFold; - IntStat d_binaryPredFold; - IntStat d_specialEqualityFolds; - IntStat d_simpITEVisits; - - HistogramStat d_inSmaller; - - Statistics(); - }; - Statistics d_statistics; }; diff --git a/test/unit/preprocessing/pass_bv_gauss_white.cpp b/test/unit/preprocessing/pass_bv_gauss_white.cpp index a8c1fae25..8f6fa7b14 100644 --- a/test/unit/preprocessing/pass_bv_gauss_white.cpp +++ b/test/unit/preprocessing/pass_bv_gauss_white.cpp @@ -21,6 +21,7 @@ #include "expr/node_manager.h" #include "preprocessing/assertion_pipeline.h" #include "preprocessing/passes/bv_gauss.h" +#include "preprocessing/preprocessing_pass_context.h" #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "test_smt.h" @@ -44,6 +45,9 @@ class TestPPWhiteBVGauss : public TestSmt { TestSmt::SetUp(); + d_preprocContext.reset(new preprocessing::PreprocessingPassContext( + d_smtEngine.get(), d_smtEngine->getEnv(), nullptr)); + d_zero = bv::utils::mkZero(16); d_p = bv::utils::mkConcat( @@ -194,6 +198,8 @@ class TestPPWhiteBVGauss : public TestSmt } } + std::unique_ptr d_preprocContext; + Node d_p; Node d_x; Node d_y; @@ -2399,7 +2405,7 @@ TEST_F(TestPPWhiteBVGauss, elim_rewrite_unique1) AssertionPipeline apipe; apipe.push_back(a); - passes::BVGauss bgauss(nullptr, "bv-gauss-unit"); + passes::BVGauss bgauss(d_preprocContext.get(), "bv-gauss-unit"); std::unordered_map res; PreprocessingPassResult pres = bgauss.applyInternal(&apipe); ASSERT_EQ(pres, PreprocessingPassResult::NO_CONFLICT); @@ -2488,7 +2494,7 @@ TEST_F(TestPPWhiteBVGauss, elim_rewrite_unique2) apipe.push_back(a); apipe.push_back(eq4); apipe.push_back(eq5); - passes::BVGauss bgauss(nullptr, "bv-gauss-unit"); + passes::BVGauss bgauss(d_preprocContext.get(), "bv-gauss-unit"); std::unordered_map res; PreprocessingPassResult pres = bgauss.applyInternal(&apipe); ASSERT_EQ(pres, PreprocessingPassResult::NO_CONFLICT); @@ -2539,7 +2545,7 @@ TEST_F(TestPPWhiteBVGauss, elim_rewrite_partial) AssertionPipeline apipe; apipe.push_back(eq1); apipe.push_back(eq2); - passes::BVGauss bgauss(nullptr, "bv-gauss-unit"); + passes::BVGauss bgauss(d_preprocContext.get(), "bv-gauss-unit"); std::unordered_map res; PreprocessingPassResult pres = bgauss.applyInternal(&apipe); ASSERT_EQ(pres, PreprocessingPassResult::NO_CONFLICT);