From 57dc21b3ee7ba8b991712361fc8804e1e584e674 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 14 Sep 2021 18:29:22 -0500 Subject: [PATCH] Add get-difficulty to the API (#7194) Adds smt2 parsing, printing and API support for get-difficulty. Adds unit tests. --- src/api/cpp/cvc5.cpp | 24 +++++++++- src/api/cpp/cvc5.h | 10 ++++ src/main/command_executor.cpp | 6 +++ src/options/main_options.toml | 8 ++++ src/parser/smt2/Smt2.g | 6 ++- src/printer/printer.cpp | 5 ++ src/printer/printer.h | 3 ++ src/printer/smt2/smt2_printer.cpp | 5 ++ src/printer/smt2/smt2_printer.h | 3 ++ src/prop/minisat/core/Solver.cc | 3 +- src/smt/command.cpp | 80 +++++++++++++++++++++++++++++++ src/smt/command.h | 23 +++++++++ src/smt/set_defaults.cpp | 4 ++ src/smt/smt_engine.cpp | 23 +++++++++ src/smt/smt_engine.h | 6 +++ test/unit/api/solver_black.cpp | 36 ++++++++++++++ 16 files changed, 242 insertions(+), 3 deletions(-) diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 03c464e08..6aa556ed8 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -7294,12 +7294,34 @@ std::vector Solver::getUnsatCore(void) const CVC5_API_TRY_CATCH_END; } +std::map Solver::getDifficulty() const +{ + CVC5_API_TRY_CATCH_BEGIN; + NodeManagerScope scope(getNodeManager()); + CVC5_API_RECOVERABLE_CHECK(d_smtEngine->getSmtMode() == SmtMode::UNSAT + || d_smtEngine->getSmtMode() == SmtMode::SAT + || d_smtEngine->getSmtMode() + == SmtMode::SAT_UNKNOWN) + << "Cannot get difficulty unless after a UNSAT, SAT or unknown response."; + //////// all checks before this line + std::map res; + std::map dmap; + d_smtEngine->getDifficultyMap(dmap); + for (const std::pair& d : dmap) + { + res[Term(this, d.first)] = Term(this, d.second); + } + return res; + //////// + CVC5_API_TRY_CATCH_END; +} + std::string Solver::getProof(void) const { CVC5_API_TRY_CATCH_BEGIN; NodeManagerScope scope(getNodeManager()); CVC5_API_CHECK(d_smtEngine->getOptions().smt.produceProofs) - << "Cannot get proof explicitly enabled (try --prooduce-proofs)"; + << "Cannot get proof explicitly enabled (try --produce-proofs)"; CVC5_API_RECOVERABLE_CHECK(d_smtEngine->getSmtMode() == SmtMode::UNSAT) << "Cannot get proof unless in unsat mode."; return d_smtEngine->getProof(); diff --git a/src/api/cpp/cvc5.h b/src/api/cpp/cvc5.h index 684b89114..4f1c60814 100644 --- a/src/api/cpp/cvc5.h +++ b/src/api/cpp/cvc5.h @@ -3895,6 +3895,16 @@ class CVC5_EXPORT Solver */ std::vector getUnsatCore() const; + /** + * Get a difficulty estimate for an asserted formula. This method is + * intended to be called immediately after any response to a checkSat. + * + * @return a map from (a subset of) the input assertions to a real value that + * is an estimate of how difficult each assertion was to solve. Unmentioned + * assertions can be assumed to have zero difficulty. + */ + std::map getDifficulty() const; + /** * Get the refutation proof * SMT-LIB: diff --git a/src/main/command_executor.cpp b/src/main/command_executor.cpp index 6b03f82e3..2c5741491 100644 --- a/src/main/command_executor.cpp +++ b/src/main/command_executor.cpp @@ -169,6 +169,12 @@ bool CommandExecutor::doCommandSingleton(Command* cmd) getterCommands.emplace_back(new GetUnsatCoreCommand()); } + if (d_solver->getOptionInfo("dump-difficulty").boolValue() + && (isResultUnsat || isResultSat || res.isSatUnknown())) + { + getterCommands.emplace_back(new GetDifficultyCommand()); + } + if (!getterCommands.empty()) { // set no time limit during dumping if applicable if (d_solver->getOptionInfo("force-no-limit-cpu-while-dump").boolValue()) diff --git a/src/options/main_options.toml b/src/options/main_options.toml index fc03de167..55bea41f0 100644 --- a/src/options/main_options.toml +++ b/src/options/main_options.toml @@ -134,6 +134,14 @@ name = "Driver" default = "false" help = "dump the full unsat core, including unlabeled assertions" +[[option]] + name = "dumpDifficulty" + category = "regular" + long = "dump-difficulty" + type = "bool" + default = "false" + help = "dump the difficulty measure after every response to check-sat" + [[option]] name = "forceNoLimitCpuWhileDump" category = "regular" diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 2c391169c..8f9e01c6c 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -408,6 +408,9 @@ command [std::unique_ptr* cmd] | /* get-unsat-core */ GET_UNSAT_CORE_TOK { PARSER_STATE->checkThatLogicIsSet(); } { cmd->reset(new GetUnsatCoreCommand); } + | /* get-difficulty */ + GET_DIFFICULTY_TOK { PARSER_STATE->checkThatLogicIsSet(); } + { cmd->reset(new GetDifficultyCommand); } | /* push */ PUSH_TOK { PARSER_STATE->checkThatLogicIsSet(); } { if( PARSER_STATE->sygus() ){ @@ -1253,7 +1256,7 @@ simpleSymbolicExprNoKeyword[std::string& s] | DEFINE_FUN_TOK | DEFINE_FUN_REC_TOK | DEFINE_FUNS_REC_TOK | DEFINE_SORT_TOK | GET_VALUE_TOK | GET_ASSIGNMENT_TOK | GET_ASSERTIONS_TOK | GET_PROOF_TOK | GET_UNSAT_ASSUMPTIONS_TOK - | GET_UNSAT_CORE_TOK | EXIT_TOK + | GET_UNSAT_CORE_TOK | GET_DIFFICULTY_TOK | EXIT_TOK | RESET_TOK | RESET_ASSERTIONS_TOK | SET_LOGIC_TOK | SET_INFO_TOK | GET_INFO_TOK | SET_OPTION_TOK | GET_OPTION_TOK | PUSH_TOK | POP_TOK | DECLARE_DATATYPES_TOK | GET_MODEL_TOK | ECHO_TOK | SIMPLIFY_TOK) @@ -2222,6 +2225,7 @@ GET_ASSERTIONS_TOK : 'get-assertions'; GET_PROOF_TOK : 'get-proof'; GET_UNSAT_ASSUMPTIONS_TOK : 'get-unsat-assumptions'; GET_UNSAT_CORE_TOK : 'get-unsat-core'; +GET_DIFFICULTY_TOK : 'get-difficulty'; EXIT_TOK : 'exit'; RESET_TOK : 'reset'; RESET_ASSERTIONS_TOK : 'reset-assertions'; diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index 01fa7a9fd..f1ad9212f 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -430,6 +430,11 @@ void Printer::toStreamCmdGetUnsatCore(std::ostream& out) const printUnknownCommand(out, "get-unsat-core"); } +void Printer::toStreamCmdGetDifficulty(std::ostream& out) const +{ + printUnknownCommand(out, "get-difficulty"); +} + void Printer::toStreamCmdGetAssertions(std::ostream& out) const { printUnknownCommand(out, "get-assertions"); diff --git a/src/printer/printer.h b/src/printer/printer.h index 499a9398f..b657a6dfa 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -213,6 +213,9 @@ class Printer /** Print get-unsat-core command */ virtual void toStreamCmdGetUnsatCore(std::ostream& out) const; + /** Print get-difficulty command */ + virtual void toStreamCmdGetDifficulty(std::ostream& out) const; + /** Print get-assertions command */ virtual void toStreamCmdGetAssertions(std::ostream& out) const; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 07c5b10d8..b92d8fa43 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1612,6 +1612,11 @@ void Smt2Printer::toStreamCmdGetUnsatCore(std::ostream& out) const out << "(get-unsat-core)" << std::endl; } +void Smt2Printer::toStreamCmdGetDifficulty(std::ostream& out) const +{ + out << "(get-difficulty)" << std::endl; +} + void Smt2Printer::toStreamCmdSetBenchmarkStatus(std::ostream& out, Result::Sat status) const { diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index fd7e0c7ac..839c016af 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -165,6 +165,9 @@ class Smt2Printer : public cvc5::Printer /** Print get-unsat-core command */ void toStreamCmdGetUnsatCore(std::ostream& out) const override; + /** Print get-difficulty command */ + void toStreamCmdGetDifficulty(std::ostream& out) const override; + /** Print get-assertions command */ void toStreamCmdGetAssertions(std::ostream& out) const override; diff --git a/src/prop/minisat/core/Solver.cc b/src/prop/minisat/core/Solver.cc index 6f99a47f0..3a2a79ddd 100644 --- a/src/prop/minisat/core/Solver.cc +++ b/src/prop/minisat/core/Solver.cc @@ -2188,7 +2188,8 @@ bool Solver::isProofEnabled() const { return d_pfManager != nullptr; } bool Solver::needProof() const { return isProofEnabled() - && options::unsatCoresMode() != options::UnsatCoresMode::ASSUMPTIONS; + && options::unsatCoresMode() != options::UnsatCoresMode::ASSUMPTIONS + && options::unsatCoresMode() != options::UnsatCoresMode::PP_ONLY; } } // namespace Minisat diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 4b04abcb2..34b2c9692 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -2411,6 +2411,86 @@ void GetUnsatCoreCommand::toStream(std::ostream& out, Printer::getPrinter(language)->toStreamCmdGetUnsatCore(out); } +/* -------------------------------------------------------------------------- */ +/* class GetDifficultyCommand */ +/* -------------------------------------------------------------------------- */ + +GetDifficultyCommand::GetDifficultyCommand() : d_sm(nullptr) {} +void GetDifficultyCommand::invoke(api::Solver* solver, SymbolManager* sm) +{ + try + { + d_sm = sm; + d_result = solver->getDifficulty(); + + d_commandStatus = CommandSuccess::instance(); + } + catch (api::CVC5ApiRecoverableException& e) + { + d_commandStatus = new CommandRecoverableFailure(e.what()); + } + catch (exception& e) + { + d_commandStatus = new CommandFailure(e.what()); + } +} + +void GetDifficultyCommand::printResult(std::ostream& out, + uint32_t verbosity) const +{ + if (!ok()) + { + this->Command::printResult(out, verbosity); + } + else + { + out << "(" << std::endl; + for (const std::pair& d : d_result) + { + out << "("; + // use name if it has one + std::string name; + if (d_sm->getExpressionName(d.first, name, true)) + { + out << name; + } + else + { + out << d.first; + } + out << " " << d.second << ")" << std::endl; + } + out << ")" << std::endl; + } +} + +const std::map& GetDifficultyCommand::getDifficultyMap() + const +{ + return d_result; +} + +Command* GetDifficultyCommand::clone() const +{ + GetDifficultyCommand* c = new GetDifficultyCommand; + c->d_sm = d_sm; + c->d_result = d_result; + return c; +} + +std::string GetDifficultyCommand::getCommandName() const +{ + return "get-difficulty"; +} + +void GetDifficultyCommand::toStream(std::ostream& out, + int toDepth, + size_t dag, + Language language) const +{ + Printer::getPrinter(language)->toStreamCmdGetDifficulty(out); +} + /* -------------------------------------------------------------------------- */ /* class GetAssertionsCommand */ /* -------------------------------------------------------------------------- */ diff --git a/src/smt/command.h b/src/smt/command.h index 7587aaa63..989e38ef0 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -1190,6 +1190,29 @@ class CVC5_EXPORT GetUnsatCoreCommand : public Command std::vector d_result; }; /* class GetUnsatCoreCommand */ +class CVC5_EXPORT GetDifficultyCommand : public Command +{ + public: + GetDifficultyCommand(); + const std::map& getDifficultyMap() const; + + void invoke(api::Solver* solver, SymbolManager* sm) override; + void printResult(std::ostream& out, uint32_t verbosity = 2) const override; + + Command* clone() const override; + std::string getCommandName() const override; + void toStream(std::ostream& out, + int toDepth = -1, + size_t dag = 1, + Language language = Language::LANG_AUTO) const override; + + protected: + /** The symbol manager we were invoked with */ + SymbolManager* d_sm; + /** the result of the get difficulty call */ + std::map d_result; +}; + class CVC5_EXPORT GetAssertionsCommand : public Command { protected: diff --git a/src/smt/set_defaults.cpp b/src/smt/set_defaults.cpp index a226de807..4d5e898fe 100644 --- a/src/smt/set_defaults.cpp +++ b/src/smt/set_defaults.cpp @@ -81,6 +81,10 @@ void SetDefaults::setDefaultsPre(Options& opts) { opts.driver.dumpUnsatCores = true; } + if (opts.driver.dumpDifficulty) + { + opts.smt.produceDifficulty = true; + } if (opts.smt.produceDifficulty) { if (opts.smt.unsatCoresMode == options::UnsatCoresMode::OFF) diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 46e83e9e7..aa7534f26 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -1774,6 +1774,29 @@ std::vector SmtEngine::getAssertions() return getAssertionsInternal(); } +void SmtEngine::getDifficultyMap(std::map& dmap) +{ + Trace("smt") << "SMT getDifficultyMap()\n"; + SmtScope smts(this); + finishInit(); + if (Dump.isOn("benchmark")) + { + getPrinter().toStreamCmdGetDifficulty(d_env->getDumpOut()); + } + if (!d_env->getOptions().smt.produceDifficulty) + { + throw ModalException( + "Cannot get difficulty when difficulty option is off."); + } + // the prop engine has the proof of false + Assert(d_pfManager); + // get difficulty map from theory engine first + TheoryEngine* te = getTheoryEngine(); + te->getDifficultyMap(dmap); + // then ask proof manager to translate dmap in terms of the input + d_pfManager->translateDifficultyMap(dmap, *d_asserts); +} + void SmtEngine::push() { SmtScope smts(this); diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 9f17fa27e..a06b2fd61 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -701,6 +701,12 @@ class CVC5_EXPORT SmtEngine */ std::vector getAssertions(); + /** + * Get difficulty map, which populates dmap, mapping input assertions + * to a value that estimates their difficulty for solving the current problem. + */ + void getDifficultyMap(std::map& dmap); + /** * Push a user-level context. * throw@ ModalException, LogicException, UnsafeInterruptException diff --git a/test/unit/api/solver_black.cpp b/test/unit/api/solver_black.cpp index 9042209e2..fe8cf87ba 100644 --- a/test/unit/api/solver_black.cpp +++ b/test/unit/api/solver_black.cpp @@ -1467,6 +1467,42 @@ TEST_F(TestApiBlackSolver, getUnsatCoreAndProof) ASSERT_NO_THROW(d_solver.getProof()); } +TEST_F(TestApiBlackSolver, getDifficulty) +{ + d_solver.setOption("produce-difficulty", "true"); + // cannot ask before a check sat + ASSERT_THROW(d_solver.getDifficulty(), CVC5ApiException); + d_solver.checkSat(); + ASSERT_NO_THROW(d_solver.getDifficulty()); +} + +TEST_F(TestApiBlackSolver, getDifficulty2) +{ + d_solver.checkSat(); + // option is not set + ASSERT_THROW(d_solver.getDifficulty(), CVC5ApiException); +} + +TEST_F(TestApiBlackSolver, getDifficulty3) +{ + d_solver.setOption("produce-difficulty", "true"); + Sort intSort = d_solver.getIntegerSort(); + Term x = d_solver.mkConst(intSort, "x"); + Term zero = d_solver.mkInteger(0); + Term ten = d_solver.mkInteger(10); + Term f0 = d_solver.mkTerm(GEQ, x, ten); + Term f1 = d_solver.mkTerm(GEQ, zero, x); + d_solver.checkSat(); + std::map dmap; + ASSERT_NO_THROW(dmap = d_solver.getDifficulty()); + // difficulty should map assertions to integer values + for (const std::pair& t : dmap) + { + ASSERT_TRUE(t.first == f0 || t.first == f1); + ASSERT_TRUE(t.second.getKind() == CONST_RATIONAL); + } +} + TEST_F(TestApiBlackSolver, getValue1) { d_solver.setOption("produce-models", "false"); -- 2.30.2