From 991af9a7a73adaa84712e93af72980ba977b1155 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 20 Aug 2018 12:21:37 -0500 Subject: [PATCH] Make sygus inference a preprocessing pass (#2334) --- src/Makefile.am | 4 +- .../passes}/sygus_inference.cpp | 117 ++++++++++-------- .../passes}/sygus_inference.h | 52 +++++--- src/smt/smt_engine.cpp | 13 +- 4 files changed, 109 insertions(+), 77 deletions(-) rename src/{theory/quantifiers => preprocessing/passes}/sygus_inference.cpp (77%) rename src/{theory/quantifiers => preprocessing/passes}/sygus_inference.h (51%) diff --git a/src/Makefile.am b/src/Makefile.am index 5e52186b9..40aa1a5af 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -99,6 +99,8 @@ libcvc4_la_SOURCES = \ preprocessing/passes/sort_infer.h \ preprocessing/passes/static_learning.cpp \ preprocessing/passes/static_learning.h \ + preprocessing/passes/sygus_inference.cpp \ + preprocessing/passes/sygus_inference.h \ preprocessing/passes/symmetry_breaker.cpp \ preprocessing/passes/symmetry_breaker.h \ preprocessing/passes/symmetry_detect.cpp \ @@ -542,8 +544,6 @@ libcvc4_la_SOURCES = \ theory/quantifiers/sygus/sygus_unif_strat.h \ theory/quantifiers/sygus/term_database_sygus.cpp \ theory/quantifiers/sygus/term_database_sygus.h \ - theory/quantifiers/sygus_inference.cpp \ - theory/quantifiers/sygus_inference.h \ theory/quantifiers/sygus_sampler.cpp \ theory/quantifiers/sygus_sampler.h \ theory/quantifiers/term_database.cpp \ diff --git a/src/theory/quantifiers/sygus_inference.cpp b/src/preprocessing/passes/sygus_inference.cpp similarity index 77% rename from src/theory/quantifiers/sygus_inference.cpp rename to src/preprocessing/passes/sygus_inference.cpp index 6232de6fe..eb8835623 100644 --- a/src/theory/quantifiers/sygus_inference.cpp +++ b/src/preprocessing/passes/sygus_inference.cpp @@ -9,28 +9,78 @@ ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** - ** \brief Implementation of sygus_inference + ** \brief Sygus inference module **/ -#include "theory/quantifiers/sygus_inference.h" +#include "preprocessing/passes/sygus_inference.h" + #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/quantifiers_rewriter.h" +using namespace std; using namespace CVC4::kind; namespace CVC4 { -namespace theory { -namespace quantifiers { +namespace preprocessing { +namespace passes { -SygusInference::SygusInference() {} +SygusInference::SygusInference(PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "sygus-infer"){}; -bool SygusInference::simplify(std::vector& assertions) +PreprocessingPassResult SygusInference::applyInternal( + AssertionPipeline* assertionsToPreprocess) { Trace("sygus-infer") << "Run sygus inference..." << std::endl; + std::vector funs; + std::vector sols; + // see if we can succesfully solve the input as a sygus problem + if (solveSygus(assertionsToPreprocess->ref(), funs, sols)) + { + Assert(funs.size() == sols.size()); + // if so, sygus gives us function definitions + SmtEngine* master_smte = smt::currentSmtEngine(); + for (unsigned i = 0, size = funs.size(); i < size; i++) + { + std::vector args; + Node sol = sols[i]; + // if it is a non-constant function + if (sol.getKind() == LAMBDA) + { + for (const Node& v : sol[0]) + { + args.push_back(v.toExpr()); + } + sol = sol[1]; + } + master_smte->defineFunction(funs[i].toExpr(), args, sol.toExpr()); + } + // apply substitution to everything, should result in SAT + for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size; + i++) + { + Node prev = (*assertionsToPreprocess)[i]; + Node curr = + prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end()); + if (curr != prev) + { + curr = theory::Rewriter::rewrite(curr); + Trace("sygus-infer-debug") + << "...rewrote " << prev << " to " << curr << std::endl; + assertionsToPreprocess->replace(i, curr); + } + } + } + return PreprocessingPassResult::NO_CONFLICT; +} + +bool SygusInference::solveSygus(std::vector& assertions, + std::vector& funs, + std::vector& sols) +{ if (assertions.empty()) { Trace("sygus-infer") << "...fail: empty assertions." << std::endl; @@ -78,19 +128,19 @@ bool SygusInference::simplify(std::vector& assertions) std::map type_count; Node pas = as; // rewrite - pas = Rewriter::rewrite(pas); + pas = theory::Rewriter::rewrite(pas); Trace("sygus-infer") << "assertion : " << pas << std::endl; if (pas.getKind() == FORALL) { // preprocess the quantified formula - pas = quantifiers::QuantifiersRewriter::preprocess(pas); + pas = theory::quantifiers::QuantifiersRewriter::preprocess(pas); Trace("sygus-infer-debug") << " ...preprocessed to " << pas << std::endl; } if (pas.getKind() == FORALL) { // it must be a standard quantifier - QAttributes qa; - QuantAttributes::computeQuantAttributes(pas, qa); + theory::quantifiers::QAttributes qa; + theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa); if (!qa.isStandard()) { Trace("sygus-infer") @@ -215,7 +265,7 @@ bool SygusInference::simplify(std::vector& assertions) // sygus attribute to mark the conjecture as a sygus conjecture Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl; Node sygusVar = nm->mkSkolem("sygus", nm->booleanType()); - SygusAttribute ca; + theory::SygusAttribute ca; sygusVar.setAttribute(ca, true); Node instAttr = nm->mkNode(INST_ATTRIBUTE, sygusVar); Node instAttrList = nm->mkNode(INST_PATTERN_LIST, instAttr); @@ -227,7 +277,6 @@ bool SygusInference::simplify(std::vector& assertions) Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl; // make a separate smt call - SmtEngine* master_smte = smt::currentSmtEngine(); SmtEngine rrSygus(nm->toExprManager()); rrSygus.setLogic(smt::currentSmtEngine()->getLogicInfo()); rrSygus.assertFormula(body.toExpr()); @@ -249,7 +298,6 @@ bool SygusInference::simplify(std::vector& assertions) it != synth_sols.end(); ++it) { - Node lambda = Node::fromExpr(it->second); Trace("sygus-infer") << " synth sol : " << it->first << " -> " << it->second << std::endl; Node ffv = Node::fromExpr(it->first); @@ -259,44 +307,15 @@ bool SygusInference::simplify(std::vector& assertions) if (itffv != ff_var_to_ff.end()) { Node ff = itffv->second; - Expr body = it->second; - std::vector args; - // if it is a non-constant function - if (lambda.getKind() == LAMBDA) - { - for (const Node& v : lambda[0]) - { - args.push_back(v.toExpr()); - } - body = it->second[1]; - } - Trace("sygus-infer") << "Define " << ff << " as " << it->second - << std::endl; - final_ff.push_back(ff); - final_ff_sol.push_back(it->second); - master_smte->defineFunction(ff.toExpr(), args, body); - } - } - - // apply substitution to everything, should result in SAT - for (unsigned i = 0, size = assertions.size(); i < size; i++) - { - Node prev = assertions[i]; - Node curr = assertions[i].substitute(final_ff.begin(), - final_ff.end(), - final_ff_sol.begin(), - final_ff_sol.end()); - if (curr != prev) - { - curr = Rewriter::rewrite(curr); - Trace("sygus-infer-debug") - << "...rewrote " << prev << " to " << curr << std::endl; - assertions[i] = curr; + Node body = Node::fromExpr(it->second); + Trace("sygus-infer") << "Define " << ff << " as " << body << std::endl; + funs.push_back(ff); + sols.push_back(body); } } return true; } -} /* CVC4::theory::quantifiers namespace */ -} /* CVC4::theory namespace */ -} /* CVC4 namespace */ +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 diff --git a/src/theory/quantifiers/sygus_inference.h b/src/preprocessing/passes/sygus_inference.h similarity index 51% rename from src/theory/quantifiers/sygus_inference.h rename to src/preprocessing/passes/sygus_inference.h index 414103fc7..5e7c6f7d0 100644 --- a/src/theory/quantifiers/sygus_inference.h +++ b/src/preprocessing/passes/sygus_inference.h @@ -9,20 +9,23 @@ ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** - ** \brief sygus_inference + ** \brief SygusInference **/ -#include "cvc4_private.h" - -#ifndef __CVC4__THEORY__QUANTIFIERS__SYGUS_INFERENCE_H -#define __CVC4__THEORY__QUANTIFIERS__SYGUS_INFERENCE_H +#ifndef __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_ +#define __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_ +#include +#include #include #include "expr/node.h" +#include "preprocessing/preprocessing_pass.h" +#include "preprocessing/preprocessing_pass_context.h" + namespace CVC4 { -namespace theory { -namespace quantifiers { +namespace preprocessing { +namespace passes { /** SygusInference * @@ -33,25 +36,36 @@ namespace quantifiers { * problem, thus obtaining a set of model substitutions under which the * assertions should simplify to true. */ -class SygusInference +class SygusInference : public PreprocessingPass { public: - SygusInference(); - ~SygusInference() {} - /** simplify assertions - * + SygusInference(PreprocessingPassContext* preprocContext); + + protected: + /** * Either replaces all uninterpreted functions in assertions by their - * interpretation in the solution found by a separate call to an SMT engine - * and returns true, or leaves the assertions unmodified and returns false. + * interpretation in a sygus solution, or leaves the assertions unmodified. + */ + PreprocessingPassResult applyInternal( + AssertionPipeline* assertionsToPreprocess) override; + /** solve sygus + * + * Returns true if we can recast the input problem assertions as a sygus + * problem and successfully solve it using a separate call to an SMT engine. * * We fail if either a sygus conjecture that corresponds to assertions cannot * be inferred, or the sygus conjecture we infer is infeasible. + * + * If this function returns true, then we add all uninterpreted symbols s in + * assertions to funs and their corresponding solution to sols. */ - bool simplify(std::vector& assertions); + bool solveSygus(std::vector& assertions, + std::vector& funs, + std::vector& sols); }; -} /* CVC4::theory::quantifiers namespace */ -} /* CVC4::theory namespace */ -} /* CVC4 namespace */ +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 -#endif /* __CVC4__THEORY__QUANTIFIERS__SYGUS_INFERENCE_H */ +#endif /* __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_ */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 418028d09..1e8ae4033 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -88,6 +88,7 @@ #include "preprocessing/passes/sep_skolem_emp.h" #include "preprocessing/passes/sort_infer.h" #include "preprocessing/passes/static_learning.h" +#include "preprocessing/passes/sygus_inference.h" #include "preprocessing/passes/symmetry_breaker.h" #include "preprocessing/passes/symmetry_detect.h" #include "preprocessing/passes/synth_rew_rules.h" @@ -119,7 +120,6 @@ #include "theory/quantifiers/quantifiers_rewriter.h" #include "theory/quantifiers/single_inv_partition.h" #include "theory/quantifiers/sygus/ce_guided_instantiation.h" -#include "theory/quantifiers/sygus_inference.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" #include "theory/sort_inference.h" @@ -2676,6 +2676,8 @@ void SmtEnginePrivate::finishInit() d_smt.d_theoryEngine->getSortInference())); std::unique_ptr staticLearning( new StaticLearning(d_preprocessingPassContext.get())); + std::unique_ptr sygusInfer( + new SygusInference(d_preprocessingPassContext.get())); std::unique_ptr sbProc( new SymBreakerPass(d_preprocessingPassContext.get())); std::unique_ptr srrProc( @@ -2713,6 +2715,8 @@ void SmtEnginePrivate::finishInit() std::move(sortInfer)); d_preprocessingPassRegistry.registerPass("static-learning", std::move(staticLearning)); + d_preprocessingPassRegistry.registerPass("sygus-infer", + std::move(sygusInfer)); d_preprocessingPassRegistry.registerPass("sym-break", std::move(sbProc)); d_preprocessingPassRegistry.registerPass("synth-rr", std::move(srrProc)); } @@ -4243,12 +4247,7 @@ void SmtEnginePrivate::processAssertions() { } if (options::sygusInference()) { - // try recast as sygus - quantifiers::SygusInference si; - if (si.simplify(d_assertions.ref())) - { - Trace("smt-proc") << "...converted to sygus conjecture." << std::endl; - } + d_preprocessingPassRegistry.getPass("sygus-infer")->apply(&d_assertions); } Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : post-quant-preprocess" << endl; } -- 2.30.2