Split abduction solver from SmtEngine (#4733)
[cvc5.git] / src / smt / smt_engine.cpp
index 12e0f443e133905fe7f123fb1218bf6f4e1ec17f..380878e41638627afd3d33251661f70af402e20f 100644 (file)
@@ -81,6 +81,7 @@
 #include "proof/theory_proof.h"
 #include "proof/unsat_core.h"
 #include "prop/prop_engine.h"
+#include "smt/abduction_solver.h"
 #include "smt/command.h"
 #include "smt/command_list.h"
 #include "smt/defined_function.h"
 #include "theory/logic_info.h"
 #include "theory/quantifiers/fun_def_process.h"
 #include "theory/quantifiers/single_inv_partition.h"
-#include "theory/quantifiers/sygus/sygus_abduct.h"
 #include "theory/quantifiers/sygus/synth_engine.h"
 #include "theory/quantifiers/term_util.h"
 #include "theory/quantifiers_engine.h"
@@ -642,6 +642,7 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr)
       d_proofManager(nullptr),
       d_rewriter(new theory::Rewriter()),
       d_definedFunctions(nullptr),
+      d_abductSolver(nullptr),
       d_assertionList(nullptr),
       d_assignments(nullptr),
       d_modelGlobalCommands(),
@@ -817,6 +818,12 @@ void SmtEngine::finishInit()
   }
   d_dumpCommands.clear();
 
+  // subsolvers
+  if (options::produceAbducts())
+  {
+    d_abductSolver.reset(new AbductionSolver(this));
+  }
+
   PROOF( ProofManager::currentPM()->setLogic(d_logic); );
   PROOF({
       for(TheoryId id = theory::THEORY_FIRST; id < theory::THEORY_LAST; ++id) {
@@ -2161,18 +2168,17 @@ Expr SmtEngine::simplify(const Expr& ex)
   return n.toExpr();
 }
 
-Expr SmtEngine::expandDefinitions(const Expr& ex)
+Node SmtEngine::expandDefinitions(const Node& ex)
 {
   d_private->spendResource(ResourceManager::Resource::PreprocessStep);
 
-  Assert(ex.getExprManager() == d_exprManager);
   SmtScope smts(this);
   finalOptionsAreSet();
   doPendingPops();
   Trace("smt") << "SMT expandDefinitions(" << ex << ")" << endl;
 
   // Substitute out any abstract values in ex.
-  Expr e = d_private->substituteAbstractValues(Node::fromExpr(ex)).toExpr();
+  Node e = d_private->substituteAbstractValues(ex);
   if(options::typeChecking()) {
     // Ensure expr is type-checked at this point.
     e.getType(true);
@@ -2180,10 +2186,10 @@ Expr SmtEngine::expandDefinitions(const Expr& ex)
 
   unordered_map<Node, Node, NodeHashFunction> cache;
   Node n = d_private->getProcessAssertions()->expandDefinitions(
-      Node::fromExpr(e), cache, /* expandOnly = */ true);
-  n = postprocess(n, TypeNode::fromType(e.getType()));
+      e, cache, /* expandOnly = */ true);
+  n = postprocess(n, e.getType());
 
-  return n.toExpr();
+  return n;
 }
 
 // TODO(#1108): Simplify the error reporting of this method.
@@ -2985,69 +2991,11 @@ void SmtEngine::checkInterpol(Expr interpol,
 {
 }
 
-void SmtEngine::checkAbduct(Expr a)
+void SmtEngine::checkAbduct(Node a)
 {
   Assert(a.getType().isBoolean());
-  Trace("check-abduct") << "SmtEngine::checkAbduct: get expanded assertions"
-                        << std::endl;
-
-  std::vector<Expr> asserts = getExpandedAssertions();
-  asserts.push_back(a);
-
-  // two checks: first, consistent with assertions, second, implies negated goal
-  // is unsatisfiable.
-  for (unsigned j = 0; j < 2; j++)
-  {
-    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
-                          << ": make new SMT engine" << std::endl;
-    // Start new SMT engine to check solution
-    SmtEngine abdChecker(d_exprManager, &d_options);
-    abdChecker.setIsInternalSubsolver();
-    abdChecker.setLogic(getLogicInfo());
-    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
-                          << ": asserting formulas" << std::endl;
-    for (const Expr& e : asserts)
-    {
-      abdChecker.assertFormula(e);
-    }
-    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
-                          << ": check the assertions" << std::endl;
-    Result r = abdChecker.checkSat();
-    Trace("check-abduct") << "SmtEngine::checkAbduct: phase " << j
-                          << ": result is " << r << endl;
-    std::stringstream serr;
-    bool isError = false;
-    if (j == 0)
-    {
-      if (r.asSatisfiabilityResult().isSat() != Result::SAT)
-      {
-        isError = true;
-        serr << "SmtEngine::checkAbduct(): produced solution cannot be shown "
-                "to be consisconsistenttent with assertions, result was "
-             << r;
-      }
-      Trace("check-abduct")
-          << "SmtEngine::checkAbduct: goal is " << d_abdConj << std::endl;
-      // add the goal to the set of assertions
-      Assert(!d_abdConj.isNull());
-      asserts.push_back(d_abdConj);
-    }
-    else
-    {
-      if (r.asSatisfiabilityResult().isSat() != Result::UNSAT)
-      {
-        isError = true;
-        serr << "SmtEngine::checkAbduct(): negated goal cannot be shown "
-                "unsatisfiable with produced solution, result was "
-             << r;
-      }
-    }
-    // did we get an unexpected result?
-    if (isError)
-    {
-      InternalError() << serr.str();
-    }
-  }
+  // check it with the abduction solver
+  return d_abductSolver->checkAbduct(a);
 }
 
 // TODO(#1108): Simplify the error reporting of this method.
@@ -3217,54 +3165,11 @@ bool SmtEngine::getInterpol(const Expr& conj, Expr& interpol)
   return getInterpol(conj, grammarType, interpol);
 }
 
-bool SmtEngine::getAbduct(const Expr& conj, const Type& grammarType, Expr& abd)
+bool SmtEngine::getAbduct(const Node& conj,
+                          const TypeNode& grammarType,
+                          Node& abd)
 {
-  SmtScope smts(this);
-
-  if (!options::produceAbducts())
-  {
-    const char* msg = "Cannot get abduct when produce-abducts options is off.";
-    throw ModalException(msg);
-  }
-  Trace("sygus-abduct") << "SmtEngine::getAbduct: conjecture " << conj
-                        << std::endl;
-  std::vector<Expr> easserts = getExpandedAssertions();
-  std::vector<Node> axioms;
-  for (unsigned i = 0, size = easserts.size(); i < size; i++)
-  {
-    axioms.push_back(Node::fromExpr(easserts[i]));
-  }
-  std::vector<Node> asserts(axioms.begin(), axioms.end());
-  // negate the conjecture
-  Node conjn = Node::fromExpr(conj);
-  // must expand definitions
-  std::unordered_map<Node, Node, NodeHashFunction> cache;
-  conjn = d_private->getProcessAssertions()->expandDefinitions(conjn, cache);
-  // now negate
-  conjn = conjn.negate();
-  d_abdConj = conjn.toExpr();
-  asserts.push_back(conjn);
-  std::string name("A");
-  Node aconj = theory::quantifiers::SygusAbduct::mkAbductionConjecture(
-      name, asserts, axioms, TypeNode::fromType(grammarType));
-  // should be a quantified conjecture with one function-to-synthesize
-  Assert(aconj.getKind() == kind::FORALL && aconj[0].getNumChildren() == 1);
-  // remember the abduct-to-synthesize
-  d_sssf = aconj[0][0].toExpr();
-  Trace("sygus-abduct") << "SmtEngine::getAbduct: made conjecture : " << aconj
-                        << ", solving for " << d_sssf << std::endl;
-  // we generate a new smt engine to do the abduction query
-  d_subsolver.reset(
-      new SmtEngine(NodeManager::currentNM()->toExprManager(), &d_options));
-  d_subsolver->setIsInternalSubsolver();
-  // get the logic
-  LogicInfo l = d_logic.getUnlockedCopy();
-  // enable everything needed for sygus
-  l.enableSygus();
-  d_subsolver->setLogic(l);
-  // assert the abduction query
-  d_subsolver->assertFormula(aconj.toExpr());
-  if (getAbductInternal(abd))
+  if (d_abductSolver->getAbduct(conj, grammarType, abd))
   {
     // successfully generated an abduct, update to abduct state
     d_smtMode = SMT_MODE_ABDUCT;
@@ -3275,68 +3180,9 @@ bool SmtEngine::getAbduct(const Expr& conj, const Type& grammarType, Expr& abd)
   return false;
 }
 
-bool SmtEngine::getAbductInternal(Expr& abd)
+bool SmtEngine::getAbduct(const Node& conj, Node& abd)
 {
-  // should have initialized the subsolver by now
-  Assert(d_subsolver != nullptr);
-  Trace("sygus-abduct") << "  SmtEngine::getAbduct check sat..." << std::endl;
-  Result r = d_subsolver->checkSat();
-  Trace("sygus-abduct") << "  SmtEngine::getAbduct result: " << r << std::endl;
-  if (r.asSatisfiabilityResult().isSat() == Result::UNSAT)
-  {
-    // get the synthesis solution
-    std::map<Expr, Expr> sols;
-    d_subsolver->getSynthSolutions(sols);
-    Assert(sols.size() == 1);
-    std::map<Expr, Expr>::iterator its = sols.find(d_sssf);
-    if (its != sols.end())
-    {
-      Trace("sygus-abduct")
-          << "SmtEngine::getAbduct: solution is " << its->second << std::endl;
-      Node abdn = Node::fromExpr(its->second);
-      if (abdn.getKind() == kind::LAMBDA)
-      {
-        abdn = abdn[1];
-      }
-      // get the grammar type for the abduct
-      Node af = Node::fromExpr(d_sssf);
-      Node agdtbv = af.getAttribute(theory::SygusSynthFunVarListAttribute());
-      Assert(!agdtbv.isNull());
-      Assert(agdtbv.getKind() == kind::BOUND_VAR_LIST);
-      // convert back to original
-      // must replace formal arguments of abd with the free variables in the
-      // input problem that they correspond to.
-      std::vector<Node> vars;
-      std::vector<Node> syms;
-      SygusVarToTermAttribute sta;
-      for (const Node& bv : agdtbv)
-      {
-        vars.push_back(bv);
-        syms.push_back(bv.hasAttribute(sta) ? bv.getAttribute(sta) : bv);
-      }
-      abdn =
-          abdn.substitute(vars.begin(), vars.end(), syms.begin(), syms.end());
-
-      // convert to expression
-      abd = abdn.toExpr();
-
-      // if check abducts option is set, we check the correctness
-      if (options::checkAbducts())
-      {
-        checkAbduct(abd);
-      }
-      return true;
-    }
-    Trace("sygus-abduct") << "SmtEngine::getAbduct: could not find solution!"
-                          << std::endl;
-    throw RecoverableModalException("Could not find solution for get-abduct.");
-  }
-  return false;
-}
-
-bool SmtEngine::getAbduct(const Expr& conj, Expr& abd)
-{
-  Type grammarType;
+  TypeNode grammarType;
   return getAbduct(conj, grammarType, abd);
 }