From 8bd0cb485ecdc405102e7cbace37baf5c687f1b8 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 4 Mar 2022 08:37:54 -0600 Subject: [PATCH] Add support for get learned literals in the API (#8099) This command will eventually take a mode; for now it assumes a default implementation. I've opened cvc5/cvc5-wishues#104 to track this. This is a feature requested by Certora. --- src/api/cpp/cvc5.cpp | 36 ++++++++-- src/api/cpp/cvc5.h | 11 ++++ src/api/java/io/github/cvc5/api/Solver.java | 15 +++++ src/api/java/jni/solver.cpp | 16 +++++ src/api/python/cvc5.pxd | 1 + src/api/python/cvc5.pxi | 18 +++++ src/options/smt_options.toml | 8 +++ src/parser/smt2/Smt2.g | 4 ++ src/printer/printer.cpp | 5 ++ src/printer/printer.h | 3 + src/printer/smt2/smt2_printer.cpp | 5 ++ src/printer/smt2/smt2_printer.h | 3 + src/prop/prop_engine.cpp | 2 +- src/prop/prop_engine.h | 2 +- src/prop/theory_proxy.cpp | 12 ++-- src/prop/theory_proxy.h | 2 +- src/prop/zero_level_learner.cpp | 17 +++-- src/prop/zero_level_learner.h | 12 ++-- src/smt/command.cpp | 66 +++++++++++++++++++ src/smt/command.h | 21 ++++++ src/smt/solver_engine.cpp | 11 ++++ src/smt/solver_engine.h | 8 +++ test/regress/CMakeLists.txt | 1 + .../regress1/get-learned-literals.smt2 | 17 +++++ test/unit/api/cpp/solver_black.cpp | 28 ++++++++ test/unit/api/java/SolverTest.java | 28 ++++++++ test/unit/api/python/test_solver.py | 21 ++++++ 27 files changed, 349 insertions(+), 24 deletions(-) create mode 100644 test/regress/regress1/get-learned-literals.smt2 diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 51f8b911d..5ffb23e9d 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -3013,6 +3013,17 @@ std::vector Term::termVectorToNodes(const std::vector& terms) return res; } +std::vector Term::nodeVectorToTerms(const Solver* slv, + const std::vector& nodes) +{ + std::vector res; + for (const Node& n : nodes) + { + res.push_back(Term(slv, n)); + } + return res; +} + bool Term::isReal32Value() const { CVC5_API_TRY_CATCH_BEGIN; @@ -6997,12 +7008,7 @@ std::vector Solver::getAssertions(void) const /* Can not use * return std::vector(assertions.begin(), assertions.end()); * here since constructor is private */ - std::vector res; - for (const Node& e : assertions) - { - res.push_back(Term(this, e)); - } - return res; + return Term::nodeVectorToTerms(this, assertions); //////// CVC5_API_TRY_CATCH_END; } @@ -7306,6 +7312,24 @@ std::string Solver::getProof(void) const CVC5_API_TRY_CATCH_END; } +std::vector Solver::getLearnedLiterals(void) const +{ + CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_CHECK(d_slv->getOptions().smt.produceLearnedLiterals) + << "Cannot get learned literals unless enabled (try " + "--produce-learned-literals)"; + CVC5_API_RECOVERABLE_CHECK(d_slv->getSmtMode() == SmtMode::UNSAT + || d_slv->getSmtMode() == SmtMode::SAT + || d_slv->getSmtMode() == SmtMode::SAT_UNKNOWN) + << "Cannot get learned literals unless after a UNSAT, SAT or UNKNOWN " + "response."; + //////// all checks before this line + std::vector lits = d_slv->getLearnedLiterals(); + return Term::nodeVectorToTerms(this, lits); + //////// + CVC5_API_TRY_CATCH_END; +} + Term Solver::getValue(const Term& term) const { CVC5_API_TRY_CATCH_BEGIN; diff --git a/src/api/cpp/cvc5.h b/src/api/cpp/cvc5.h index e75bef293..d3662c832 100644 --- a/src/api/cpp/cvc5.h +++ b/src/api/cpp/cvc5.h @@ -1585,6 +1585,9 @@ class CVC5_EXPORT Term private: /** Helper to convert a vector of Terms to internal Nodes. */ std::vector static termVectorToNodes(const std::vector& terms); + /** Helper to convert a vector of internal Nodes to Terms. */ + std::vector static nodeVectorToTerms(const Solver* slv, + const std::vector& nodes); /** Helper method to collect all elements of a set. */ static void collectSet(std::set& set, @@ -4181,6 +4184,14 @@ class CVC5_EXPORT Solver */ std::string getProof() const; + /** + * Get learned literals + * + * @return a list of literals that were learned at top-level. In other words, + * these are literals that are entailed by the current set of assertions. + */ + std::vector getLearnedLiterals() const; + /** * Get the value of the given term in the current model. * diff --git a/src/api/java/io/github/cvc5/api/Solver.java b/src/api/java/io/github/cvc5/api/Solver.java index 43493f784..f1ad0cea6 100644 --- a/src/api/java/io/github/cvc5/api/Solver.java +++ b/src/api/java/io/github/cvc5/api/Solver.java @@ -1760,6 +1760,21 @@ public class Solver implements IPointer, AutoCloseable */ // TODO: void echo(std::ostream& out, String str) + /** + * Get a list of literals that are entailed by the current set of assertions + * SMT-LIB: + * {@code + * ( get-learned-literals ) + * } + * @return the list of learned literals + */ + public Term[] getLearnedLiterals() { + long[] retPointers = getLearnedLiterals(pointer); + return Utils.getTerms(this, retPointers); + } + + private native long[] getLearnedLiterals(long pointer); + /** * Get the list of asserted formulas. * SMT-LIB: diff --git a/src/api/java/jni/solver.cpp b/src/api/java/jni/solver.cpp index ec54b09b7..9af316cbf 100644 --- a/src/api/java/jni/solver.cpp +++ b/src/api/java/jni/solver.cpp @@ -1783,6 +1783,22 @@ Java_io_github_cvc5_api_Solver_defineFunsRec(JNIEnv* env, CVC5_JAVA_API_TRY_CATCH_END(env); } +/* + * Class: io_github_cvc5_api_Solver + * Method: getLearnedLiterals + * Signature: (J)[J + */ +JNIEXPORT jlongArray JNICALL Java_io_github_cvc5_api_Solver_getLearnedLiterals( + JNIEnv* env, jobject, jlong pointer) +{ + CVC5_JAVA_API_TRY_CATCH_BEGIN; + Solver* solver = reinterpret_cast(pointer); + std::vector assertions = solver->getLearnedLiterals(); + jlongArray ret = getPointersFromObjects(env, assertions); + return ret; + CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, nullptr); +} + /* * Class: io_github_cvc5_api_Solver * Method: getAssertions diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd index 357d5d149..095028365 100644 --- a/src/api/python/cvc5.pxd +++ b/src/api/python/cvc5.pxd @@ -273,6 +273,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": Term term, bint glbl) except + Term defineFunsRec(vector[Term]& funs, vector[vector[Term]]& bound_vars, vector[Term]& terms, bint glbl) except + + vector[Term] getLearnedLiterals() except + vector[Term] getAssertions() except + string getInfo(const string& flag) except + string getOption(string& option) except + diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi index 03bb2c9d3..d843d55ed 100644 --- a/src/api/python/cvc5.pxi +++ b/src/api/python/cvc5.pxi @@ -1944,6 +1944,24 @@ cdef class Solver: for t in terms: vf.push_back(( t).cterm) + def getLearnedLiterals(self): + """Get a list of literals that are entailed by the current set of assertions + + SMT-LIB: + + .. code-block:: smtlib + + ( get-learned-literals ) + + :return: the list of literals + """ + lits = [] + for a in self.csolver.getLearnedLiterals(): + term = Term(self) + term.cterm = a + lits.append(term) + return lits + def getAssertions(self): """Get the list of asserted formulas. diff --git a/src/options/smt_options.toml b/src/options/smt_options.toml index e37d84d1a..8e961a3bd 100644 --- a/src/options/smt_options.toml +++ b/src/options/smt_options.toml @@ -116,6 +116,14 @@ name = "SMT Layer" name = "values" help = "Block models based on the concrete model values for the free variables." +[[option]] + name = "produceLearnedLiterals" + category = "regular" + long = "produce-learned-literals" + type = "bool" + default = "false" + help = "produce learned literals, support get-learned-literals" + [[option]] name = "produceProofs" category = "regular" diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 355a3469e..ae935b7aa 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -404,6 +404,9 @@ command [std::unique_ptr* cmd] | /* get-difficulty */ GET_DIFFICULTY_TOK { PARSER_STATE->checkThatLogicIsSet(); } { cmd->reset(new GetDifficultyCommand); } + | /* get-learned-literals */ + GET_LEARNED_LITERALS_TOK { PARSER_STATE->checkThatLogicIsSet(); } + { cmd->reset(new GetLearnedLiteralsCommand); } | /* push */ PUSH_TOK { PARSER_STATE->checkThatLogicIsSet(); } ( k=INTEGER_LITERAL @@ -2204,6 +2207,7 @@ GET_PROOF_TOK : 'get-proof'; GET_UNSAT_ASSUMPTIONS_TOK : 'get-unsat-assumptions'; GET_UNSAT_CORE_TOK : 'get-unsat-core'; GET_DIFFICULTY_TOK : 'get-difficulty'; +GET_LEARNED_LITERALS_TOK : { !PARSER_STATE->strictModeEnabled() }? 'get-learned-literals'; EXIT_TOK : 'exit'; RESET_TOK : 'reset'; RESET_ASSERTIONS_TOK : 'reset-assertions'; diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index 12b52e284..f3909ae3f 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -445,6 +445,11 @@ void Printer::toStreamCmdGetDifficulty(std::ostream& out) const printUnknownCommand(out, "get-difficulty"); } +void Printer::toStreamCmdGetLearnedLiterals(std::ostream& out) const +{ + printUnknownCommand(out, "get-learned-literals"); +} + void Printer::toStreamCmdGetAssertions(std::ostream& out) const { printUnknownCommand(out, "get-assertions"); diff --git a/src/printer/printer.h b/src/printer/printer.h index 2a9283001..33824d42d 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -226,6 +226,9 @@ class Printer /** Print get-difficulty command */ virtual void toStreamCmdGetDifficulty(std::ostream& out) const; + /** Print get-learned-literals command */ + virtual void toStreamCmdGetLearnedLiterals(std::ostream& out) const; + /** Print get-assertions command */ virtual void toStreamCmdGetAssertions(std::ostream& out) const; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 773d5fd4b..80bf2b601 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1726,6 +1726,11 @@ void Smt2Printer::toStreamCmdGetDifficulty(std::ostream& out) const out << "(get-difficulty)" << std::endl; } +void Smt2Printer::toStreamCmdGetLearnedLiterals(std::ostream& out) const +{ + out << "(get-learned-literals)" << std::endl; +} + void Smt2Printer::toStreamCmdSetBenchmarkLogic(std::ostream& out, const std::string& logic) const { diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index 4170fe88f..6f78f9668 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -200,6 +200,9 @@ class Smt2Printer : public cvc5::Printer /** Print get-difficulty command */ void toStreamCmdGetDifficulty(std::ostream& out) const override; + /** Print get-learned-literals command */ + void toStreamCmdGetLearnedLiterals(std::ostream& out) const override; + /** Print get-assertions command */ void toStreamCmdGetAssertions(std::ostream& out) const override; diff --git a/src/prop/prop_engine.cpp b/src/prop/prop_engine.cpp index bf9a80008..734da0b76 100644 --- a/src/prop/prop_engine.cpp +++ b/src/prop/prop_engine.cpp @@ -672,7 +672,7 @@ std::shared_ptr PropEngine::getRefutation() return cdp.getProofFor(fnode); } -const std::unordered_set& PropEngine::getLearnedZeroLevelLiterals() const +std::vector PropEngine::getLearnedZeroLevelLiterals() const { return d_theoryProxy->getLearnedZeroLevelLiterals(); } diff --git a/src/prop/prop_engine.h b/src/prop/prop_engine.h index 96cd2ce1d..fc3c68ca5 100644 --- a/src/prop/prop_engine.h +++ b/src/prop/prop_engine.h @@ -295,7 +295,7 @@ class PropEngine : protected EnvObj std::shared_ptr getRefutation(); /** Get the zero-level assertions */ - const std::unordered_set& getLearnedZeroLevelLiterals() const; + std::vector getLearnedZeroLevelLiterals() const; private: /** Dump out the satisfying assignment (after SAT result) */ diff --git a/src/prop/theory_proxy.cpp b/src/prop/theory_proxy.cpp index 357248726..44bc79d43 100644 --- a/src/prop/theory_proxy.cpp +++ b/src/prop/theory_proxy.cpp @@ -52,7 +52,8 @@ TheoryProxy::TheoryProxy(Env& env, d_skdm(skdm), d_zll(nullptr) { - bool trackTopLevelLearned = isOutputOn(OutputTag::LEARNED_LITS); + bool trackTopLevelLearned = isOutputOn(OutputTag::LEARNED_LITS) + || options().smt.produceLearnedLiterals; if (trackTopLevelLearned) { d_zll = std::make_unique(env, propEngine); @@ -289,10 +290,13 @@ void TheoryProxy::getSkolems(TNode node, void TheoryProxy::preRegister(Node n) { d_theoryEngine->preRegister(n); } -const std::unordered_set& TheoryProxy::getLearnedZeroLevelLiterals() const +std::vector TheoryProxy::getLearnedZeroLevelLiterals() const { - Assert(d_zll != nullptr); - return d_zll->getLearnedZeroLevelLiterals(); + if (d_zll != nullptr) + { + return d_zll->getLearnedZeroLevelLiterals(); + } + return {}; } } // namespace prop diff --git a/src/prop/theory_proxy.h b/src/prop/theory_proxy.h index 85070b4c7..3e95f4a03 100644 --- a/src/prop/theory_proxy.h +++ b/src/prop/theory_proxy.h @@ -154,7 +154,7 @@ class TheoryProxy : protected EnvObj, public Registrar void preRegister(Node n) override; /** Get the zero-level assertions */ - const std::unordered_set& getLearnedZeroLevelLiterals() const; + std::vector getLearnedZeroLevelLiterals() const; private: /** The prop engine we are using. */ diff --git a/src/prop/zero_level_learner.cpp b/src/prop/zero_level_learner.cpp index 3562545b8..b5415f714 100644 --- a/src/prop/zero_level_learner.cpp +++ b/src/prop/zero_level_learner.cpp @@ -30,7 +30,11 @@ namespace prop { ZeroLevelLearner::ZeroLevelLearner(Env& env, PropEngine* propEngine) : EnvObj(env), d_propEngine(propEngine), + d_levelZeroAsserts(userContext()), + d_levelZeroAssertsLearned(userContext()), d_nonZeroAssert(context(), false), + d_ppnAtoms(userContext()), + d_pplAtoms(userContext()), d_assertNoLearnCount(0) { } @@ -39,7 +43,7 @@ ZeroLevelLearner::~ZeroLevelLearner() {} void ZeroLevelLearner::getAtoms(TNode a, std::unordered_set& visited, - std::unordered_set& ppLits) + NodeSet& ppLits) { std::vector visit; TNode cur; @@ -67,7 +71,6 @@ void ZeroLevelLearner::notifyInputFormulas( const std::vector& ppl) { d_assertNoLearnCount = 0; - d_ppnAtoms.clear(); // Copy the preprocessed assertions and skolem map information directly // Also, compute the set of literals in the preprocessed assertions std::unordered_set visited; @@ -146,10 +149,14 @@ void ZeroLevelLearner::notifyAsserted(TNode assertion) } } -const std::unordered_set& ZeroLevelLearner::getLearnedZeroLevelLiterals() - const +std::vector ZeroLevelLearner::getLearnedZeroLevelLiterals() const { - return d_levelZeroAssertsLearned; + std::vector ret; + for (const Node& n : d_levelZeroAssertsLearned) + { + ret.push_back(n); + } + return ret; } } // namespace prop diff --git a/src/prop/zero_level_learner.h b/src/prop/zero_level_learner.h index e9979e723..c15e04e98 100644 --- a/src/prop/zero_level_learner.h +++ b/src/prop/zero_level_learner.h @@ -58,30 +58,30 @@ class ZeroLevelLearner : protected EnvObj void notifyAsserted(TNode assertion); /** Get the zero-level assertions */ - const std::unordered_set& getLearnedZeroLevelLiterals() const; + std::vector getLearnedZeroLevelLiterals() const; private: static void getAtoms(TNode a, std::unordered_set& visited, - std::unordered_set& ppLits); + NodeSet& ppLits); /** The prop engine we are using. */ PropEngine* d_propEngine; /** Set of literals that hold at level 0 */ - std::unordered_set d_levelZeroAsserts; + NodeSet d_levelZeroAsserts; /** Set of learnable literals that hold at level 0 */ - std::unordered_set d_levelZeroAssertsLearned; + NodeSet d_levelZeroAssertsLearned; /** Whether we have seen an assertion level > 0 */ context::CDO d_nonZeroAssert; /** Preprocessed literals that are not learned */ - std::unordered_set d_ppnAtoms; + NodeSet d_ppnAtoms; /** Already learned TEMPORARY */ - std::unordered_set d_pplAtoms; + NodeSet d_pplAtoms; /** Current counter of assertions */ size_t d_assertNoLearnCount; diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 7a81aa0a0..f96488452 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -2549,6 +2549,72 @@ void GetDifficultyCommand::toStream(std::ostream& out, Printer::getPrinter(language)->toStreamCmdGetDifficulty(out); } +/* -------------------------------------------------------------------------- */ +/* class GetLearnedLiteralsCommand */ +/* -------------------------------------------------------------------------- */ + +GetLearnedLiteralsCommand::GetLearnedLiteralsCommand() {} +void GetLearnedLiteralsCommand::invoke(api::Solver* solver, SymbolManager* sm) +{ + try + { + d_result = solver->getLearnedLiterals(); + + d_commandStatus = CommandSuccess::instance(); + } + catch (api::CVC5ApiRecoverableException& e) + { + d_commandStatus = new CommandRecoverableFailure(e.what()); + } + catch (exception& e) + { + d_commandStatus = new CommandFailure(e.what()); + } +} + +void GetLearnedLiteralsCommand::printResult(std::ostream& out) const +{ + if (!ok()) + { + this->Command::printResult(out); + } + else + { + out << "(" << std::endl; + for (const api::Term& lit : d_result) + { + out << lit << std::endl; + } + out << ")" << std::endl; + } +} + +const std::vector& GetLearnedLiteralsCommand::getLearnedLiterals() + const +{ + return d_result; +} + +Command* GetLearnedLiteralsCommand::clone() const +{ + GetLearnedLiteralsCommand* c = new GetLearnedLiteralsCommand; + c->d_result = d_result; + return c; +} + +std::string GetLearnedLiteralsCommand::getCommandName() const +{ + return "get-learned-literals"; +} + +void GetLearnedLiteralsCommand::toStream(std::ostream& out, + int toDepth, + size_t dag, + Language language) const +{ + Printer::getPrinter(language)->toStreamCmdGetLearnedLiterals(out); +} + /* -------------------------------------------------------------------------- */ /* class GetAssertionsCommand */ /* -------------------------------------------------------------------------- */ diff --git a/src/smt/command.h b/src/smt/command.h index 9a4b04196..b250d0b59 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -1254,6 +1254,27 @@ class CVC5_EXPORT GetDifficultyCommand : public Command std::map d_result; }; +class CVC5_EXPORT GetLearnedLiteralsCommand : public Command +{ + public: + GetLearnedLiteralsCommand(); + const std::vector& getLearnedLiterals() const; + + void invoke(api::Solver* solver, SymbolManager* sm) override; + void printResult(std::ostream& out) const override; + + Command* clone() const override; + std::string getCommandName() const override; + void toStream(std::ostream& out, + int toDepth = -1, + size_t dag = 1, + Language language = Language::LANG_AUTO) const override; + + protected: + /** the result of the get learned literals call */ + std::vector d_result; +}; + class CVC5_EXPORT GetAssertionsCommand : public Command { protected: diff --git a/src/smt/solver_engine.cpp b/src/smt/solver_engine.cpp index 1223e9875..583eaae7b 100644 --- a/src/smt/solver_engine.cpp +++ b/src/smt/solver_engine.cpp @@ -1272,6 +1272,17 @@ Node SolverEngine::getSepHeapExpr() { return getSepHeapAndNilExpr().first; } Node SolverEngine::getSepNilExpr() { return getSepHeapAndNilExpr().second; } +std::vector SolverEngine::getLearnedLiterals() +{ + Trace("smt") << "SMT getLearnedLiterals()" << std::endl; + SolverEngineScope smts(this); + // note that the default mode for learned literals is via the prop engine, + // although other modes could use the preprocessor + PropEngine* pe = getPropEngine(); + Assert(pe != nullptr); + return pe->getLearnedZeroLevelLiterals(); +} + void SolverEngine::checkProof() { Assert(d_env->getOptions().smt.produceProofs); diff --git a/src/smt/solver_engine.h b/src/smt/solver_engine.h index fd1057f3d..c6cd77dd7 100644 --- a/src/smt/solver_engine.h +++ b/src/smt/solver_engine.h @@ -271,6 +271,14 @@ class CVC5_EXPORT SolverEngine /** When using separation logic, obtain the expression for nil. */ Node getSepNilExpr(); + /** + * Get the list of top-level learned literals that are entailed by the current + * set of assertions. + * + * TODO (wishue #104): implement for different modes + */ + std::vector getLearnedLiterals(); + /** * Get an aspect of the current SMT execution environment. * @throw OptionException diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 0e1e84f1a..94e9d3c41 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1863,6 +1863,7 @@ set(regress_1_tests regress1/fmf/sort-inf-int.smt2 regress1/fmf/with-ind-104-core.smt2 regress1/gensys_brn001.smt2 + regress1/get-learned-literals.smt2 regress1/ho/bug_freeVar_BDD_General_data_270.p regress1/ho/bug_freevar_PHI004^4-delta.smt2 regress1/ho/bound_var_bug.p diff --git a/test/regress/regress1/get-learned-literals.smt2 b/test/regress/regress1/get-learned-literals.smt2 new file mode 100644 index 000000000..d52a931ae --- /dev/null +++ b/test/regress/regress1/get-learned-literals.smt2 @@ -0,0 +1,17 @@ +; SCRUBBER: sed -e 's/(>=.*/learned-ineq/' +; EXPECT: sat +; EXPECT: ( +; EXPECT: learned-ineq +; EXPECT: learned-ineq +; EXPECT: ) +(set-logic ALL) +(set-option :produce-learned-literals true) +(declare-fun x () Int) +(declare-fun y () Int) +(declare-fun z () Int) + +(assert (> x 5)) +(assert (< y 4)) +(assert (or (< x y) (> z 0))) +(check-sat) +(get-learned-literals) \ No newline at end of file diff --git a/test/unit/api/cpp/solver_black.cpp b/test/unit/api/cpp/solver_black.cpp index ac6e4543e..fa573f5da 100644 --- a/test/unit/api/cpp/solver_black.cpp +++ b/test/unit/api/cpp/solver_black.cpp @@ -1694,6 +1694,8 @@ TEST_F(TestApiBlackSolver, getDifficulty3) Term ten = d_solver.mkInteger(10); Term f0 = d_solver.mkTerm(GEQ, x, ten); Term f1 = d_solver.mkTerm(GEQ, zero, x); + d_solver.assertFormula(f0); + d_solver.assertFormula(f1); d_solver.checkSat(); std::map dmap; ASSERT_NO_THROW(dmap = d_solver.getDifficulty()); @@ -1705,6 +1707,32 @@ TEST_F(TestApiBlackSolver, getDifficulty3) } } +TEST_F(TestApiBlackSolver, getLearnedLiterals) +{ + d_solver.setOption("produce-learned-literals", "true"); + // cannot ask before a check sat + ASSERT_THROW(d_solver.getLearnedLiterals(), CVC5ApiException); + d_solver.checkSat(); + ASSERT_NO_THROW(d_solver.getLearnedLiterals()); +} + +TEST_F(TestApiBlackSolver, getLearnedLiterals2) +{ + d_solver.setOption("produce-learned-literals", "true"); + Sort intSort = d_solver.getIntegerSort(); + Term x = d_solver.mkConst(intSort, "x"); + Term y = d_solver.mkConst(intSort, "y"); + Term zero = d_solver.mkInteger(0); + Term ten = d_solver.mkInteger(10); + Term f0 = d_solver.mkTerm(GEQ, x, ten); + Term f1 = d_solver.mkTerm( + OR, d_solver.mkTerm(GEQ, zero, x), d_solver.mkTerm(GEQ, y, zero)); + d_solver.assertFormula(f0); + d_solver.assertFormula(f1); + d_solver.checkSat(); + ASSERT_NO_THROW(d_solver.getLearnedLiterals()); +} + TEST_F(TestApiBlackSolver, getValue1) { d_solver.setOption("produce-models", "false"); diff --git a/test/unit/api/java/SolverTest.java b/test/unit/api/java/SolverTest.java index 80921481a..6bcd38e92 100644 --- a/test/unit/api/java/SolverTest.java +++ b/test/unit/api/java/SolverTest.java @@ -1669,6 +1669,8 @@ class SolverTest Term ten = d_solver.mkInteger(10); Term f0 = d_solver.mkTerm(GEQ, x, ten); Term f1 = d_solver.mkTerm(GEQ, zero, x); + d_solver.assertFormula(f0); + d_solver.assertFormula(f1); d_solver.checkSat(); Map dmap = d_solver.getDifficulty(); // difficulty should map assertions to integer values @@ -1679,6 +1681,32 @@ class SolverTest } } + @Test + void getLearnedLiterals() { + d_solver.setOption("produce-learned-literals", "true"); + // cannot ask before a check sat + assertThrows(CVC5ApiException.class, () -> d_solver.getLearnedLiterals()); + d_solver.checkSat(); + assertDoesNotThrow(() -> d_solver.getLearnedLiterals()); + } + + @Test + void getLearnedLiterals2() { + d_solver.setOption("produce-learned-literals", "true"); + Sort intSort = d_solver.getIntegerSort(); + Term x = d_solver.mkConst(intSort, "x"); + Term y = d_solver.mkConst(intSort, "y"); + Term zero = d_solver.mkInteger(0); + Term ten = d_solver.mkInteger(10); + Term f0 = d_solver.mkTerm(GEQ, x, ten); + Term f1 = d_solver.mkTerm( + OR, d_solver.mkTerm(GEQ, zero, x), d_solver.mkTerm(GEQ, y, zero)); + d_solver.assertFormula(f0); + d_solver.assertFormula(f1); + d_solver.checkSat(); + assertDoesNotThrow(() -> d_solver.getLearnedLiterals()); + } + @Test void getValue1() { d_solver.setOption("produce-models", "false"); diff --git a/test/unit/api/python/test_solver.py b/test/unit/api/python/test_solver.py index 9b89f7187..24f180c7b 100644 --- a/test/unit/api/python/test_solver.py +++ b/test/unit/api/python/test_solver.py @@ -1259,6 +1259,27 @@ def test_get_unsat_core3(solver): res = solver.checkSat() assert res.isUnsat() +def test_learned_literals(solver): + solver.setOption("produce-learned-literals", "true") + with pytest.raises(RuntimeError): + solver.getLearnedLiterals() + solver.checkSat() + solver.getLearnedLiterals() + +def test_learned_literals2(solver): + solver.setOption("produce-learned-literals", "true") + intSort = solver.getIntegerSort() + x = solver.mkConst(intSort, "x") + y = solver.mkConst(intSort, "y") + zero = solver.mkInteger(0) + ten = solver.mkInteger(10) + f0 = solver.mkTerm(Kind.Geq, x, ten) + f1 = solver.mkTerm(Kind.Or, solver.mkTerm(Kind.Geq, zero, x), solver.mkTerm(Kind.Geq, y, zero)) + solver.assertFormula(f0) + solver.assertFormula(f1) + solver.checkSat() + solver.getLearnedLiterals() + def test_get_value1(solver): solver.setOption("produce-models", "false") -- 2.30.2