Add get-difficulty to the API (#7194)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 14 Sep 2021 23:29:22 +0000 (18:29 -0500)
committerGitHub <noreply@github.com>
Tue, 14 Sep 2021 23:29:22 +0000 (23:29 +0000)
Adds smt2 parsing, printing and API support for get-difficulty. Adds unit tests.

16 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/main/command_executor.cpp
src/options/main_options.toml
src/parser/smt2/Smt2.g
src/printer/printer.cpp
src/printer/printer.h
src/printer/smt2/smt2_printer.cpp
src/printer/smt2/smt2_printer.h
src/prop/minisat/core/Solver.cc
src/smt/command.cpp
src/smt/command.h
src/smt/set_defaults.cpp
src/smt/smt_engine.cpp
src/smt/smt_engine.h
test/unit/api/solver_black.cpp

index 03c464e08703c5cea1871871c8fed2ae6977519d..6aa556ed8b3e6c4fd235c0d374f593c01f26025d 100644 (file)
@@ -7294,12 +7294,34 @@ std::vector<Term> Solver::getUnsatCore(void) const
   CVC5_API_TRY_CATCH_END;
 }
 
+std::map<Term, Term> Solver::getDifficulty() const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  NodeManagerScope scope(getNodeManager());
+  CVC5_API_RECOVERABLE_CHECK(d_smtEngine->getSmtMode() == SmtMode::UNSAT
+                             || d_smtEngine->getSmtMode() == SmtMode::SAT
+                             || d_smtEngine->getSmtMode()
+                                    == SmtMode::SAT_UNKNOWN)
+      << "Cannot get difficulty unless after a UNSAT, SAT or unknown response.";
+  //////// all checks before this line
+  std::map<Term, Term> res;
+  std::map<Node, Node> dmap;
+  d_smtEngine->getDifficultyMap(dmap);
+  for (const std::pair<const Node, Node>& d : dmap)
+  {
+    res[Term(this, d.first)] = Term(this, d.second);
+  }
+  return res;
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+
 std::string Solver::getProof(void) const
 {
   CVC5_API_TRY_CATCH_BEGIN;
   NodeManagerScope scope(getNodeManager());
   CVC5_API_CHECK(d_smtEngine->getOptions().smt.produceProofs)
-      << "Cannot get proof explicitly enabled (try --prooduce-proofs)";
+      << "Cannot get proof explicitly enabled (try --produce-proofs)";
   CVC5_API_RECOVERABLE_CHECK(d_smtEngine->getSmtMode() == SmtMode::UNSAT)
       << "Cannot get proof unless in unsat mode.";
   return d_smtEngine->getProof();
index 684b891148496025d6a1f04471c72553d4e70fca..4f1c60814fb7c2ad058c96da9f7f6d90b7ba6e4a 100644 (file)
@@ -3895,6 +3895,16 @@ class CVC5_EXPORT Solver
    */
   std::vector<Term> getUnsatCore() const;
 
+  /**
+   * Get a difficulty estimate for an asserted formula. This method is
+   * intended to be called immediately after any response to a checkSat.
+   *
+   * @return a map from (a subset of) the input assertions to a real value that
+   * is an estimate of how difficult each assertion was to solve. Unmentioned
+   * assertions can be assumed to have zero difficulty.
+   */
+  std::map<Term, Term> getDifficulty() const;
+
   /**
    * Get the refutation proof
    * SMT-LIB:
index 6b03f82e303881b271f185de0bb57ca2bcc40c3c..2c57414915d821bb19eb2f9c4d6132bc04dc4c56 100644 (file)
@@ -169,6 +169,12 @@ bool CommandExecutor::doCommandSingleton(Command* cmd)
       getterCommands.emplace_back(new GetUnsatCoreCommand());
     }
 
+    if (d_solver->getOptionInfo("dump-difficulty").boolValue()
+        && (isResultUnsat || isResultSat || res.isSatUnknown()))
+    {
+      getterCommands.emplace_back(new GetDifficultyCommand());
+    }
+
     if (!getterCommands.empty()) {
       // set no time limit during dumping if applicable
       if (d_solver->getOptionInfo("force-no-limit-cpu-while-dump").boolValue())
index fc03de16721a65b40cfad1ed9eaa2cd29dbc338b..55bea41f0a14978edfc0f8d9b16cc94e9bf2321c 100644 (file)
@@ -134,6 +134,14 @@ name   = "Driver"
   default    = "false"
   help       = "dump the full unsat core, including unlabeled assertions"
 
+[[option]]
+  name       = "dumpDifficulty"
+  category   = "regular"
+  long       = "dump-difficulty"
+  type       = "bool"
+  default    = "false"
+  help       = "dump the difficulty measure after every response to check-sat"
+
 [[option]]
   name       = "forceNoLimitCpuWhileDump"
   category   = "regular"
index 2c391169c2c94d5920297a86b2cd84978fdc2e9f..8f9e01c6cf545b013349149080cedb690c3d71e4 100644 (file)
@@ -408,6 +408,9 @@ command [std::unique_ptr<cvc5::Command>* cmd]
   | /* get-unsat-core */
     GET_UNSAT_CORE_TOK { PARSER_STATE->checkThatLogicIsSet(); }
     { cmd->reset(new GetUnsatCoreCommand); }
+  | /* get-difficulty */
+    GET_DIFFICULTY_TOK { PARSER_STATE->checkThatLogicIsSet(); }
+    { cmd->reset(new GetDifficultyCommand); }
   | /* push */
     PUSH_TOK { PARSER_STATE->checkThatLogicIsSet(); }
     { if( PARSER_STATE->sygus() ){
@@ -1253,7 +1256,7 @@ simpleSymbolicExprNoKeyword[std::string& s]
         | DEFINE_FUN_TOK | DEFINE_FUN_REC_TOK | DEFINE_FUNS_REC_TOK
         | DEFINE_SORT_TOK | GET_VALUE_TOK | GET_ASSIGNMENT_TOK
         | GET_ASSERTIONS_TOK | GET_PROOF_TOK | GET_UNSAT_ASSUMPTIONS_TOK
-        | GET_UNSAT_CORE_TOK | EXIT_TOK
+        | GET_UNSAT_CORE_TOK | GET_DIFFICULTY_TOK | EXIT_TOK
         | RESET_TOK | RESET_ASSERTIONS_TOK | SET_LOGIC_TOK | SET_INFO_TOK
         | GET_INFO_TOK | SET_OPTION_TOK | GET_OPTION_TOK | PUSH_TOK | POP_TOK
         | DECLARE_DATATYPES_TOK | GET_MODEL_TOK | ECHO_TOK | SIMPLIFY_TOK)
@@ -2222,6 +2225,7 @@ GET_ASSERTIONS_TOK : 'get-assertions';
 GET_PROOF_TOK : 'get-proof';
 GET_UNSAT_ASSUMPTIONS_TOK : 'get-unsat-assumptions';
 GET_UNSAT_CORE_TOK : 'get-unsat-core';
+GET_DIFFICULTY_TOK : 'get-difficulty';
 EXIT_TOK : 'exit';
 RESET_TOK : 'reset';
 RESET_ASSERTIONS_TOK : 'reset-assertions';
index 01fa7a9fdbd6d9d5d873750ebcc9a672f86fefe7..f1ad9212f45834dca2df5f62416c6e9802af21e8 100644 (file)
@@ -430,6 +430,11 @@ void Printer::toStreamCmdGetUnsatCore(std::ostream& out) const
   printUnknownCommand(out, "get-unsat-core");
 }
 
+void Printer::toStreamCmdGetDifficulty(std::ostream& out) const
+{
+  printUnknownCommand(out, "get-difficulty");
+}
+
 void Printer::toStreamCmdGetAssertions(std::ostream& out) const
 {
   printUnknownCommand(out, "get-assertions");
index 499a9398fa3c09fa83a8a971c9399d674c653b4a..b657a6dfafe933335706fc1f4f4e48137e369c37 100644 (file)
@@ -213,6 +213,9 @@ class Printer
   /** Print get-unsat-core command */
   virtual void toStreamCmdGetUnsatCore(std::ostream& out) const;
 
+  /** Print get-difficulty command */
+  virtual void toStreamCmdGetDifficulty(std::ostream& out) const;
+
   /** Print get-assertions command */
   virtual void toStreamCmdGetAssertions(std::ostream& out) const;
 
index 07c5b10d819cf292a8bb772641e28ffceae422e2..b92d8fa431fc7842729bbd7662a1c4813b4d30e0 100644 (file)
@@ -1612,6 +1612,11 @@ void Smt2Printer::toStreamCmdGetUnsatCore(std::ostream& out) const
   out << "(get-unsat-core)" << std::endl;
 }
 
+void Smt2Printer::toStreamCmdGetDifficulty(std::ostream& out) const
+{
+  out << "(get-difficulty)" << std::endl;
+}
+
 void Smt2Printer::toStreamCmdSetBenchmarkStatus(std::ostream& out,
                                                 Result::Sat status) const
 {
index fd7e0c7ac75fb8503e3633c16ec71d10bd87ea0b..839c016af07559ad4a5e3e76fd2528b3de8afe73 100644 (file)
@@ -165,6 +165,9 @@ class Smt2Printer : public cvc5::Printer
   /** Print get-unsat-core command */
   void toStreamCmdGetUnsatCore(std::ostream& out) const override;
 
+  /** Print get-difficulty command */
+  void toStreamCmdGetDifficulty(std::ostream& out) const override;
+
   /** Print get-assertions command */
   void toStreamCmdGetAssertions(std::ostream& out) const override;
 
index 6f99a47f052835c786680ef1cb23e4a32d3b4c3d..3a2a79ddd12ce8814df0bb35b4130a3bf9ae90c7 100644 (file)
@@ -2188,7 +2188,8 @@ bool Solver::isProofEnabled() const { return d_pfManager != nullptr; }
 bool Solver::needProof() const
 {
   return isProofEnabled()
-         && options::unsatCoresMode() != options::UnsatCoresMode::ASSUMPTIONS;
+         && options::unsatCoresMode() != options::UnsatCoresMode::ASSUMPTIONS
+         && options::unsatCoresMode() != options::UnsatCoresMode::PP_ONLY;
 }
 
 }  // namespace Minisat
index 4b04abcb26b2d25a0a57908600f8df3f7bc16947..34b2c9692728c1b35611029a09677e32f84ac228 100644 (file)
@@ -2411,6 +2411,86 @@ void GetUnsatCoreCommand::toStream(std::ostream& out,
   Printer::getPrinter(language)->toStreamCmdGetUnsatCore(out);
 }
 
+/* -------------------------------------------------------------------------- */
+/* class GetDifficultyCommand */
+/* -------------------------------------------------------------------------- */
+
+GetDifficultyCommand::GetDifficultyCommand() : d_sm(nullptr) {}
+void GetDifficultyCommand::invoke(api::Solver* solver, SymbolManager* sm)
+{
+  try
+  {
+    d_sm = sm;
+    d_result = solver->getDifficulty();
+
+    d_commandStatus = CommandSuccess::instance();
+  }
+  catch (api::CVC5ApiRecoverableException& e)
+  {
+    d_commandStatus = new CommandRecoverableFailure(e.what());
+  }
+  catch (exception& e)
+  {
+    d_commandStatus = new CommandFailure(e.what());
+  }
+}
+
+void GetDifficultyCommand::printResult(std::ostream& out,
+                                       uint32_t verbosity) const
+{
+  if (!ok())
+  {
+    this->Command::printResult(out, verbosity);
+  }
+  else
+  {
+    out << "(" << std::endl;
+    for (const std::pair<const api::Term, api::Term>& d : d_result)
+    {
+      out << "(";
+      // use name if it has one
+      std::string name;
+      if (d_sm->getExpressionName(d.first, name, true))
+      {
+        out << name;
+      }
+      else
+      {
+        out << d.first;
+      }
+      out << " " << d.second << ")" << std::endl;
+    }
+    out << ")" << std::endl;
+  }
+}
+
+const std::map<api::Term, api::Term>& GetDifficultyCommand::getDifficultyMap()
+    const
+{
+  return d_result;
+}
+
+Command* GetDifficultyCommand::clone() const
+{
+  GetDifficultyCommand* c = new GetDifficultyCommand;
+  c->d_sm = d_sm;
+  c->d_result = d_result;
+  return c;
+}
+
+std::string GetDifficultyCommand::getCommandName() const
+{
+  return "get-difficulty";
+}
+
+void GetDifficultyCommand::toStream(std::ostream& out,
+                                    int toDepth,
+                                    size_t dag,
+                                    Language language) const
+{
+  Printer::getPrinter(language)->toStreamCmdGetDifficulty(out);
+}
+
 /* -------------------------------------------------------------------------- */
 /* class GetAssertionsCommand                                                 */
 /* -------------------------------------------------------------------------- */
index 7587aaa636b396b1a3139f03d8dd94fe13dca224..989e38ef0e46d6fefaf9c0ea142b5705481a862c 100644 (file)
@@ -1190,6 +1190,29 @@ class CVC5_EXPORT GetUnsatCoreCommand : public Command
   std::vector<api::Term> d_result;
 }; /* class GetUnsatCoreCommand */
 
+class CVC5_EXPORT GetDifficultyCommand : public Command
+{
+ public:
+  GetDifficultyCommand();
+  const std::map<api::Term, api::Term>& getDifficultyMap() const;
+
+  void invoke(api::Solver* solver, SymbolManager* sm) override;
+  void printResult(std::ostream& out, uint32_t verbosity = 2) const override;
+
+  Command* clone() const override;
+  std::string getCommandName() const override;
+  void toStream(std::ostream& out,
+                int toDepth = -1,
+                size_t dag = 1,
+                Language language = Language::LANG_AUTO) const override;
+
+ protected:
+  /** The symbol manager we were invoked with */
+  SymbolManager* d_sm;
+  /** the result of the get difficulty call */
+  std::map<api::Term, api::Term> d_result;
+};
+
 class CVC5_EXPORT GetAssertionsCommand : public Command
 {
  protected:
index a226de8076e17e9fc2e2de29f5702ec3ff219935..4d5e898feb4306be12a2d19a8ddee3d2c8daa88c 100644 (file)
@@ -81,6 +81,10 @@ void SetDefaults::setDefaultsPre(Options& opts)
   {
     opts.driver.dumpUnsatCores = true;
   }
+  if (opts.driver.dumpDifficulty)
+  {
+    opts.smt.produceDifficulty = true;
+  }
   if (opts.smt.produceDifficulty)
   {
     if (opts.smt.unsatCoresMode == options::UnsatCoresMode::OFF)
index 46e83e9e7a804248a74224dff9def8249854c290..aa7534f26619f9fa842c41a617364cfbf78dd294 100644 (file)
@@ -1774,6 +1774,29 @@ std::vector<Node> SmtEngine::getAssertions()
   return getAssertionsInternal();
 }
 
+void SmtEngine::getDifficultyMap(std::map<Node, Node>& dmap)
+{
+  Trace("smt") << "SMT getDifficultyMap()\n";
+  SmtScope smts(this);
+  finishInit();
+  if (Dump.isOn("benchmark"))
+  {
+    getPrinter().toStreamCmdGetDifficulty(d_env->getDumpOut());
+  }
+  if (!d_env->getOptions().smt.produceDifficulty)
+  {
+    throw ModalException(
+        "Cannot get difficulty when difficulty option is off.");
+  }
+  // the prop engine has the proof of false
+  Assert(d_pfManager);
+  // get difficulty map from theory engine first
+  TheoryEngine* te = getTheoryEngine();
+  te->getDifficultyMap(dmap);
+  // then ask proof manager to translate dmap in terms of the input
+  d_pfManager->translateDifficultyMap(dmap, *d_asserts);
+}
+
 void SmtEngine::push()
 {
   SmtScope smts(this);
index 9f17fa27ebd5d93f9c6583f01c86c22045b89ff4..a06b2fd6123d9904a58763e81a23d86958dc9748 100644 (file)
@@ -701,6 +701,12 @@ class CVC5_EXPORT SmtEngine
    */
   std::vector<Node> getAssertions();
 
+  /**
+   * Get difficulty map, which populates dmap, mapping input assertions
+   * to a value that estimates their difficulty for solving the current problem.
+   */
+  void getDifficultyMap(std::map<Node, Node>& dmap);
+
   /**
    * Push a user-level context.
    * throw@ ModalException, LogicException, UnsafeInterruptException
index 9042209e28765fec92ace729b16c760bdf1f4b2a..fe8cf87baffe63fb12f577a8346d56f6d8299569 100644 (file)
@@ -1467,6 +1467,42 @@ TEST_F(TestApiBlackSolver, getUnsatCoreAndProof)
   ASSERT_NO_THROW(d_solver.getProof());
 }
 
+TEST_F(TestApiBlackSolver, getDifficulty)
+{
+  d_solver.setOption("produce-difficulty", "true");
+  // cannot ask before a check sat
+  ASSERT_THROW(d_solver.getDifficulty(), CVC5ApiException);
+  d_solver.checkSat();
+  ASSERT_NO_THROW(d_solver.getDifficulty());
+}
+
+TEST_F(TestApiBlackSolver, getDifficulty2)
+{
+  d_solver.checkSat();
+  // option is not set
+  ASSERT_THROW(d_solver.getDifficulty(), CVC5ApiException);
+}
+
+TEST_F(TestApiBlackSolver, getDifficulty3)
+{
+  d_solver.setOption("produce-difficulty", "true");
+  Sort intSort = d_solver.getIntegerSort();
+  Term x = d_solver.mkConst(intSort, "x");
+  Term zero = d_solver.mkInteger(0);
+  Term ten = d_solver.mkInteger(10);
+  Term f0 = d_solver.mkTerm(GEQ, x, ten);
+  Term f1 = d_solver.mkTerm(GEQ, zero, x);
+  d_solver.checkSat();
+  std::map<Term, Term> dmap;
+  ASSERT_NO_THROW(dmap = d_solver.getDifficulty());
+  // difficulty should map assertions to integer values
+  for (const std::pair<Term, Term>& t : dmap)
+  {
+    ASSERT_TRUE(t.first == f0 || t.first == f1);
+    ASSERT_TRUE(t.second.getKind() == CONST_RATIONAL);
+  }
+}
+
 TEST_F(TestApiBlackSolver, getValue1)
 {
   d_solver.setOption("produce-models", "false");