Split SygusSolver from SmtEngine (#4891)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 18 Aug 2020 16:41:46 +0000 (11:41 -0500)
committerGitHub <noreply@github.com>
Tue, 18 Aug 2020 16:41:46 +0000 (11:41 -0500)
This is the solver for standard SyGuS queries. Notice it now depends only on SmtSolver and not SmtEngine.

This PR updates Expr -> Node for the sygus interface in SmtEngine.

SmtEnginePrivate is no longer needed and is deleted with this PR.

src/CMakeLists.txt
src/api/cvc4cpp.cpp
src/preprocessing/passes/sygus_inference.cpp
src/smt/abduction_solver.cpp
src/smt/command.cpp
src/smt/smt_engine.cpp
src/smt/smt_engine.h
src/smt/sygus_solver.cpp [new file with mode: 0644]
src/smt/sygus_solver.h [new file with mode: 0644]
src/theory/quantifiers/sygus/sygus_interpol.cpp

index 62a9b35d0e0ed1ed9211f892a07ac22a0b92d532..10cf23fdf23bf64a4ce544b875895a3bbcdcb48d 100644 (file)
@@ -271,6 +271,8 @@ libcvc4_add_sources(
   smt/smt_solver.h
   smt/smt_statistics_registry.cpp
   smt/smt_statistics_registry.h
+  smt/sygus_solver.cpp
+  smt/sygus_solver.h
   smt/term_formula_removal.cpp
   smt/term_formula_removal.h
   smt/update_ostream.h
index 5ccb4c6c167547fe3475784a08f7ca943fcfe3b4..150f84301876f82fa6e6c5f6ff782e14189eb8e6 100644 (file)
@@ -5157,7 +5157,7 @@ Term Solver::mkSygusVar(Sort sort, const std::string& symbol) const
   Expr res = d_exprMgr->mkBoundVar(symbol, *sort.d_type);
   (void)res.getType(true); /* kick off type checking */
 
-  d_smtEngine->declareSygusVar(symbol, res, *sort.d_type);
+  d_smtEngine->declareSygusVar(symbol, res, TypeNode::fromType(*sort.d_type));
 
   return Term(this, res);
 
@@ -5279,14 +5279,21 @@ Term Solver::synthFunHelper(const std::string& symbol,
                      ? *sort.d_type
                      : d_exprMgr->mkFunctionType(varTypes, *sort.d_type);
 
-  Expr fun = d_exprMgr->mkBoundVar(symbol, funType);
+  Node fun = getNodeManager()->mkBoundVar(symbol, TypeNode::fromType(funType));
   (void)fun.getType(true); /* kick off type checking */
 
-  d_smtEngine->declareSynthFun(symbol,
-                               fun,
-                               g == nullptr ? funType : *g->resolve().d_type,
-                               isInv,
-                               termVectorToExprs(boundVars));
+  std::vector<Node> bvns;
+  for (const Term& t : boundVars)
+  {
+    bvns.push_back(*t.d_node);
+  }
+
+  d_smtEngine->declareSynthFun(
+      symbol,
+      fun,
+      TypeNode::fromType(g == nullptr ? funType : *g->resolve().d_type),
+      isInv,
+      bvns);
 
   return Term(this, fun);
 
@@ -5373,13 +5380,12 @@ Term Solver::getSynthSolution(Term term) const
   CVC4_API_ARG_CHECK_NOT_NULL(term);
   CVC4_API_SOLVER_CHECK_TERM(term);
 
-  std::map<CVC4::Expr, CVC4::Expr> map;
+  std::map<CVC4::Node, CVC4::Node> map;
   CVC4_API_CHECK(d_smtEngine->getSynthSolutions(map))
       << "The solver is not in a state immediately preceeded by a "
          "successful call to checkSynth";
 
-  std::map<CVC4::Expr, CVC4::Expr>::const_iterator it =
-      map.find(term.d_node->toExpr());
+  std::map<CVC4::Node, CVC4::Node>::const_iterator it = map.find(*term.d_node);
 
   CVC4_API_CHECK(it != map.cend()) << "Synth solution not found for given term";
 
@@ -5403,7 +5409,7 @@ std::vector<Term> Solver::getSynthSolutions(
         << "non-null term";
   }
 
-  std::map<CVC4::Expr, CVC4::Expr> map;
+  std::map<CVC4::Node, CVC4::Node> map;
   CVC4_API_CHECK(d_smtEngine->getSynthSolutions(map))
       << "The solver is not in a state immediately preceeded by a "
          "successful call to checkSynth";
@@ -5413,8 +5419,8 @@ std::vector<Term> Solver::getSynthSolutions(
 
   for (size_t i = 0, n = terms.size(); i < n; ++i)
   {
-    std::map<CVC4::Expr, CVC4::Expr>::const_iterator it =
-        map.find(terms[i].d_node->toExpr());
+    std::map<CVC4::Node, CVC4::Node>::const_iterator it =
+        map.find(*terms[i].d_node);
 
     CVC4_API_CHECK(it != map.cend())
         << "Synth solution not found for term at index " << i;
index 7336ac159c395353b93965869b9fb665da487668..d44321a35c42d1f86d45f97a6b724e7d48c4c2a7 100644 (file)
@@ -313,25 +313,25 @@ bool SygusInference::solveSygus(std::vector<Node>& assertions,
     return false;
   }
   // get the synthesis solutions
-  std::map<Expr, Expr> synth_sols;
+  std::map<Node, Node> synth_sols;
   rrSygus->getSynthSolutions(synth_sols);
 
   std::vector<Node> final_ff;
   std::vector<Node> final_ff_sol;
-  for (std::map<Expr, Expr>::iterator it = synth_sols.begin();
+  for (std::map<Node, Node>::iterator it = synth_sols.begin();
        it != synth_sols.end();
        ++it)
   {
     Trace("sygus-infer") << "  synth sol : " << it->first << " -> "
                          << it->second << std::endl;
-    Node ffv = Node::fromExpr(it->first);
+    Node ffv = it->first;
     std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
     // all synthesis solutions should correspond to a variable we introduced
     Assert(itffv != ff_var_to_ff.end());
     if (itffv != ff_var_to_ff.end())
     {
       Node ff = itffv->second;
-      Node body2 = Node::fromExpr(it->second);
+      Node body2 = it->second;
       Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
       funs.push_back(ff);
       sols.push_back(body2);
index 01e2a4f0fe0e0c82a1fd34cd81da378819a1605d..2a6346c18d214562628573b530aa4cdb3a362efa 100644 (file)
@@ -89,16 +89,15 @@ bool AbductionSolver::getAbductInternal(Node& abd)
   if (r.asSatisfiabilityResult().isSat() == Result::UNSAT)
   {
     // get the synthesis solution
-    std::map<Expr, Expr> sols;
+    std::map<Node, Node> sols;
     d_subsolver->getSynthSolutions(sols);
     Assert(sols.size() == 1);
-    Expr essf = d_sssf.toExpr();
-    std::map<Expr, Expr>::iterator its = sols.find(essf);
+    std::map<Node, Node>::iterator its = sols.find(d_sssf);
     if (its != sols.end())
     {
       Trace("sygus-abduct")
           << "SmtEngine::getAbduct: solution is " << its->second << std::endl;
-      abd = Node::fromExpr(its->second);
+      abd = its->second;
       if (abd.getKind() == kind::LAMBDA)
       {
         abd = abd[1];
index f5c997318ff4f63ca5a05143651ce9cd953eb11e..2383167a66610a29b5b68fbbf648e469c86ceb62 100644 (file)
@@ -577,7 +577,8 @@ void DeclareSygusVarCommand::invoke(SmtEngine* smtEngine)
 {
   try
   {
-    smtEngine->declareSygusVar(d_symbol, d_var, d_type);
+    smtEngine->declareSygusVar(
+        d_symbol, Node::fromExpr(d_var), TypeNode::fromType(d_type));
     d_commandStatus = CommandSuccess::instance();
   }
   catch (exception& e)
@@ -622,7 +623,8 @@ void DeclareSygusFunctionCommand::invoke(SmtEngine* smtEngine)
 {
   try
   {
-    smtEngine->declareSygusFunctionVar(d_symbol, d_func, d_type);
+    smtEngine->declareSygusVar(
+        d_symbol, Node::fromExpr(d_func), TypeNode::fromType(d_type));
     d_commandStatus = CommandSuccess::instance();
   }
   catch (exception& e)
@@ -687,13 +689,19 @@ void SynthFunCommand::invoke(SmtEngine* smtEngine)
 {
   try
   {
-    smtEngine->declareSynthFun(d_symbol,
-                               d_fun.getExpr(),
-                               d_grammar == nullptr
-                                   ? d_sort.getType()
-                                   : d_grammar->resolve().getType(),
-                               d_isInv,
-                               api::termVectorToExprs(d_vars));
+    std::vector<Node> vns;
+    for (const api::Term& t : d_vars)
+    {
+      vns.push_back(Node::fromExpr(t.getExpr()));
+    }
+    smtEngine->declareSynthFun(
+        d_symbol,
+        Node::fromExpr(d_fun.getExpr()),
+        TypeNode::fromType(d_grammar == nullptr
+                               ? d_sort.getType()
+                               : d_grammar->resolve().getType()),
+        d_isInv,
+        vns);
     d_commandStatus = CommandSuccess::instance();
   }
   catch (exception& e)
index e709406d87943b0effd3ebc033fc54626ee901ed..83f3cb5e0f92214ffe3bae5276d4a61974e69cd2 100644 (file)
@@ -96,6 +96,7 @@
 #include "smt/smt_engine_state.h"
 #include "smt/smt_engine_stats.h"
 #include "smt/smt_solver.h"
+#include "smt/sygus_solver.h"
 #include "smt/term_formula_removal.h"
 #include "smt/update_ostream.h"
 #include "smt_util/boolean_simplification.h"
@@ -140,56 +141,6 @@ extern const char* const plf_signatures;
 
 namespace smt {
 
-/**
- * This is an inelegant solution, but for the present, it will work.
- * The point of this is to separate the public and private portions of
- * the SmtEngine class, so that smt_engine.h doesn't
- * include "expr/node.h", which is a private CVC4 header (and can lead
- * to linking errors due to the improper inlining of non-visible symbols
- * into user code!).
- *
- * The "real" solution (that which is usually implemented) is to move
- * ALL the implementation to SmtEnginePrivate and maintain a
- * heap-allocated instance of it in SmtEngine.  SmtEngine (the public
- * one) becomes an "interface shell" which simply acts as a forwarder
- * of method calls.
- */
-class SmtEnginePrivate
-{
- public:
-
-  /* Finishes the initialization of the private portion of SMTEngine. */
-  void finishInit();
-
-  /*------------------- sygus utils ------------------*/
-  /**
-   * sygus variables declared (from "declare-var" and "declare-fun" commands)
-   *
-   * The SyGuS semantics for declared variables is that they are implicitly
-   * universally quantified in the constraints.
-   */
-  std::vector<Node> d_sygusVars;
-  /** sygus constraints */
-  std::vector<Node> d_sygusConstraints;
-  /** functions-to-synthesize */
-  std::vector<Node> d_sygusFunSymbols;
-  /**
-   * Whether we need to reconstruct the sygus conjecture.
-   */
-  CDO<bool> d_sygusConjectureStale;
-  /*------------------- end of sygus utils ------------------*/
-
- public:
-  SmtEnginePrivate(SmtEngine& smt)
-      : d_sygusConjectureStale(smt.getUserContext(), true)
-  {
-  }
-
-  ~SmtEnginePrivate()
-  {
-  }
-};/* class SmtEnginePrivate */
-
 }/* namespace CVC4::smt */
 
 SmtEngine::SmtEngine(ExprManager* em, Options* optr)
@@ -206,13 +157,13 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr)
       d_proofManager(nullptr),
       d_rewriter(new theory::Rewriter()),
       d_definedFunctions(nullptr),
+      d_sygusSolver(nullptr),
       d_abductSolver(nullptr),
       d_assignments(nullptr),
       d_defineCommands(),
       d_logic(),
       d_originalOptions(),
       d_isInternalSubsolver(false),
-      d_private(nullptr),
       d_statisticsRegistry(nullptr),
       d_stats(nullptr),
       d_resourceManager(nullptr),
@@ -245,7 +196,6 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr)
   d_optm.reset(new smt::OptionsManager(&d_options, d_resourceManager.get()));
   d_pp.reset(
       new smt::Preprocessor(*this, getUserContext(), *d_absValues.get()));
-  d_private.reset(new smt::SmtEnginePrivate(*this));
   // listen to node manager events
   d_nodeManager->subscribeEvents(d_snmListener.get());
   // listen to resource out
@@ -255,6 +205,8 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr)
   // make the SMT solver
   d_smtSolver.reset(
       new SmtSolver(*this, *d_state, d_resourceManager.get(), *d_pp, *d_stats));
+  // make the SyGuS solver
+  d_sygusSolver.reset(new SygusSolver(*d_smtSolver, *d_pp, getUserContext()));
 
   // The ProofManager is constructed before any other proof objects such as
   // SatProof and TheoryProofs. The TheoryProofEngine and the SatProof are
@@ -421,10 +373,11 @@ SmtEngine::~SmtEngine()
     d_exprNames.reset(nullptr);
     d_dumpm.reset(nullptr);
 
+    d_sygusSolver.reset(nullptr);
+
     d_smtSolver.reset(nullptr);
 
     d_stats.reset(nullptr);
-    d_private.reset(nullptr);
     d_nodeManager->unsubscribeEvents(d_snmListener.get());
     d_snmListener.reset(nullptr);
     d_routListener.reset(nullptr);
@@ -1145,57 +1098,26 @@ Result SmtEngine::assertFormula(const Node& formula, bool inUnsatCore)
    --------------------------------------------------------------------------
 */
 
-void SmtEngine::declareSygusVar(const std::string& id, Expr var, Type type)
+void SmtEngine::declareSygusVar(const std::string& id, Node var, TypeNode type)
 {
   SmtScope smts(this);
   finishInit();
-  d_private->d_sygusVars.push_back(Node::fromExpr(var));
-  Trace("smt") << "SmtEngine::declareSygusVar: " << var << "\n";
-  Dump("raw-benchmark") << DeclareSygusVarCommand(id, var, type);
-  // don't need to set that the conjecture is stale
-}
-
-void SmtEngine::declareSygusFunctionVar(const std::string& id,
-                                        Expr var,
-                                        Type type)
-{
-  SmtScope smts(this);
-  finishInit();
-  d_private->d_sygusVars.push_back(Node::fromExpr(var));
-  Trace("smt") << "SmtEngine::declareSygusFunctionVar: " << var << "\n";
-  Dump("raw-benchmark") << DeclareSygusVarCommand(id, var, type);
-
+  d_sygusSolver->declareSygusVar(id, var, type);
+  Dump("raw-benchmark") << DeclareSygusVarCommand(
+      id, var.toExpr(), type.toType());
   // don't need to set that the conjecture is stale
 }
 
 void SmtEngine::declareSynthFun(const std::string& id,
-                                Expr func,
-                                Type sygusType,
+                                Node func,
+                                TypeNode sygusType,
                                 bool isInv,
-                                const std::vector<Expr>& vars)
+                                const std::vector<Node>& vars)
 {
   SmtScope smts(this);
   finishInit();
   d_state->doPendingPops();
-  Node fn = Node::fromExpr(func);
-  d_private->d_sygusFunSymbols.push_back(fn);
-  if (!vars.empty())
-  {
-    Expr bvl = d_exprManager->mkExpr(kind::BOUND_VAR_LIST, vars);
-    std::vector<Expr> attr_val_bvl;
-    attr_val_bvl.push_back(bvl);
-    setUserAttribute("sygus-synth-fun-var-list", func, attr_val_bvl, "");
-  }
-  // whether sygus type encodes syntax restrictions
-  TypeNode stn = TypeNode::fromType(sygusType);
-  if (sygusType.isDatatype() && stn.getDType().isSygus())
-  {
-    Node sym = d_nodeManager->mkBoundVar("sfproxy", stn);
-    std::vector<Expr> attr_value;
-    attr_value.push_back(sym.toExpr());
-    setUserAttribute("sygus-synth-grammar", func, attr_value, "");
-  }
-  Trace("smt") << "SmtEngine::declareSynthFun: " << func << "\n";
+  d_sygusSolver->declareSynthFun(id, func, sygusType, isInv, vars);
 
   // !!! TEMPORARY: We cannot construct a SynthFunCommand since we cannot
   // construct a Term-level Grammar from a Node-level sygus TypeNode. Thus we
@@ -1203,187 +1125,50 @@ void SmtEngine::declareSynthFun(const std::string& id,
 
   if (Dump.isOn("raw-benchmark"))
   {
-    std::vector<Node> nodeVars;
-    nodeVars.reserve(vars.size());
-    for (const Expr& var : vars)
-    {
-      nodeVars.push_back(Node::fromExpr(var));
-    }
-
     std::stringstream ss;
 
     Printer::getPrinter(options::outputLanguage())
-        ->toStreamCmdSynthFun(
-            ss,
-            id,
-            nodeVars,
-            func.getType().isFunction()
-                ? TypeNode::fromType(func.getType()).getRangeType()
-                : TypeNode::fromType(func.getType()),
-            isInv,
-            TypeNode::fromType(sygusType));
-    
+        ->toStreamCmdSynthFun(ss,
+                              id,
+                              vars,
+                              func.getType().isFunction()
+                                  ? func.getType().getRangeType()
+                                  : func.getType(),
+                              isInv,
+                              sygusType);
+
     // must print it on the standard output channel since it is not possible
     // to print anything except for commands with Dump.
     std::ostream& out = *d_options.getOut();
     out << ss.str() << std::endl;
   }
-
-  // sygus conjecture is now stale
-  setSygusConjectureStale();
 }
 
-void SmtEngine::assertSygusConstraint(const Node& constraint)
+void SmtEngine::assertSygusConstraint(Node constraint)
 {
   SmtScope smts(this);
   finishInit();
-  d_private->d_sygusConstraints.push_back(constraint);
-
-  Trace("smt") << "SmtEngine::assertSygusConstrant: " << constraint << "\n";
+  d_sygusSolver->assertSygusConstraint(constraint);
   Dump("raw-benchmark") << SygusConstraintCommand(constraint.toExpr());
-  // sygus conjecture is now stale
-  setSygusConjectureStale();
 }
 
-void SmtEngine::assertSygusInvConstraint(const Expr& inv,
-                                         const Expr& pre,
-                                         const Expr& trans,
-                                         const Expr& post)
+void SmtEngine::assertSygusInvConstraint(Node inv,
+                                         Node pre,
+                                         Node trans,
+                                         Node post)
 {
   SmtScope smts(this);
   finishInit();
-  // build invariant constraint
-
-  // get variables (regular and their respective primed versions)
-  std::vector<Node> terms, vars, primed_vars;
-  terms.push_back(Node::fromExpr(inv));
-  terms.push_back(Node::fromExpr(pre));
-  terms.push_back(Node::fromExpr(trans));
-  terms.push_back(Node::fromExpr(post));
-  // variables are built based on the invariant type
-  FunctionType t = static_cast<FunctionType>(inv.getType());
-  std::vector<Type> argTypes = t.getArgTypes();
-  for (const Type& ti : argTypes)
-  {
-    TypeNode tn = TypeNode::fromType(ti);
-    vars.push_back(d_nodeManager->mkBoundVar(tn));
-    d_private->d_sygusVars.push_back(vars.back());
-    std::stringstream ss;
-    ss << vars.back() << "'";
-    primed_vars.push_back(d_nodeManager->mkBoundVar(ss.str(), tn));
-    d_private->d_sygusVars.push_back(primed_vars.back());
-  }
-
-  // make relevant terms; 0 -> Inv, 1 -> Pre, 2 -> Trans, 3 -> Post
-  for (unsigned i = 0; i < 4; ++i)
-  {
-    Node op = terms[i];
-    Trace("smt-debug") << "Make inv-constraint term #" << i << " : " << op
-                       << " with type " << op.getType() << "...\n";
-    std::vector<Node> children;
-    children.push_back(op);
-    // transition relation applied over both variable lists
-    if (i == 2)
-    {
-      children.insert(children.end(), vars.begin(), vars.end());
-      children.insert(children.end(), primed_vars.begin(), primed_vars.end());
-    }
-    else
-    {
-      children.insert(children.end(), vars.begin(), vars.end());
-    }
-    terms[i] = d_nodeManager->mkNode(kind::APPLY_UF, children);
-    // make application of Inv on primed variables
-    if (i == 0)
-    {
-      children.clear();
-      children.push_back(op);
-      children.insert(children.end(), primed_vars.begin(), primed_vars.end());
-      terms.push_back(d_nodeManager->mkNode(kind::APPLY_UF, children));
-    }
-  }
-  // make constraints
-  std::vector<Node> conj;
-  conj.push_back(d_nodeManager->mkNode(kind::IMPLIES, terms[1], terms[0]));
-  Node term0_and_2 = d_nodeManager->mkNode(kind::AND, terms[0], terms[2]);
-  conj.push_back(d_nodeManager->mkNode(kind::IMPLIES, term0_and_2, terms[4]));
-  conj.push_back(d_nodeManager->mkNode(kind::IMPLIES, terms[0], terms[3]));
-  Node constraint = d_nodeManager->mkNode(kind::AND, conj);
-
-  d_private->d_sygusConstraints.push_back(constraint);
-
-  Trace("smt") << "SmtEngine::assertSygusInvConstrant: " << constraint << "\n";
-  Dump("raw-benchmark") << SygusInvConstraintCommand(inv, pre, trans, post);
-  // sygus conjecture is now stale
-  setSygusConjectureStale();
+  d_sygusSolver->assertSygusInvConstraint(inv, pre, trans, post);
+  Dump("raw-benchmark") << SygusInvConstraintCommand(
+      inv.toExpr(), pre.toExpr(), trans.toExpr(), post.toExpr());
 }
 
 Result SmtEngine::checkSynth()
 {
   SmtScope smts(this);
-
-  if (options::incrementalSolving())
-  {
-    // TODO (project #7)
-    throw ModalException(
-        "Cannot make check-synth commands when incremental solving is enabled");
-  }
-  std::vector<Node> query;
-  if (d_private->d_sygusConjectureStale)
-  {
-    // build synthesis conjecture from asserted constraints and declared
-    // variables/functions
-    Node sygusVar =
-        d_nodeManager->mkSkolem("sygus", d_nodeManager->booleanType());
-    Node inst_attr = d_nodeManager->mkNode(kind::INST_ATTRIBUTE, sygusVar);
-    Node sygusAttr = d_nodeManager->mkNode(kind::INST_PATTERN_LIST, inst_attr);
-    std::vector<Node> bodyv;
-    Trace("smt") << "Sygus : Constructing sygus constraint...\n";
-    unsigned n_constraints = d_private->d_sygusConstraints.size();
-    Node body = n_constraints == 0
-                    ? d_nodeManager->mkConst(true)
-                    : (n_constraints == 1
-                           ? d_private->d_sygusConstraints[0]
-                           : d_nodeManager->mkNode(
-                               kind::AND, d_private->d_sygusConstraints));
-    body = body.notNode();
-    Trace("smt") << "...constructed sygus constraint " << body << std::endl;
-    if (!d_private->d_sygusVars.empty())
-    {
-      Node boundVars =
-          d_nodeManager->mkNode(kind::BOUND_VAR_LIST, d_private->d_sygusVars);
-      body = d_nodeManager->mkNode(kind::EXISTS, boundVars, body);
-      Trace("smt") << "...constructed exists " << body << std::endl;
-    }
-    if (!d_private->d_sygusFunSymbols.empty())
-    {
-      Node boundVars = d_nodeManager->mkNode(kind::BOUND_VAR_LIST,
-                                             d_private->d_sygusFunSymbols);
-      body = d_nodeManager->mkNode(kind::FORALL, boundVars, body, sygusAttr);
-    }
-    Trace("smt") << "...constructed forall " << body << std::endl;
-
-    // set attribute for synthesis conjecture
-    setUserAttribute("sygus", sygusVar.toExpr(), {}, "");
-
-    Trace("smt") << "Check synthesis conjecture: " << body << std::endl;
-    Dump("raw-benchmark") << CheckSynthCommand();
-
-    d_private->d_sygusConjectureStale = false;
-
-    // TODO (project #7): if incremental, we should push a context and assert
-    query.push_back(body);
-  }
-
-  Result r = checkSatInternal(query, true, false);
-
-  // Check that synthesis solutions satisfy the conjecture
-  if (options::checkSynthSol()
-      && r.asSatisfiabilityResult().isSat() == Result::UNSAT)
-  {
-    checkSynthSolution();
-  }
-  return r;
+  finishInit();
+  return d_sygusSolver->checkSynth(*d_asserts);
 }
 
 /*
@@ -2032,160 +1817,12 @@ void SmtEngine::checkModel(bool hardFailure) {
   Notice() << "SmtEngine::checkModel(): all assertions checked out OK !" << endl;
 }
 
-void SmtEngine::checkSynthSolution()
-{
-  NodeManager* nm = NodeManager::currentNM();
-  Notice() << "SmtEngine::checkSynthSolution(): checking synthesis solution" << endl;
-  std::map<Node, std::map<Node, Node>> sol_map;
-  TheoryEngine* te = getTheoryEngine();
-  Assert(te != nullptr);
-  /* Get solutions and build auxiliary vectors for substituting */
-  if (!te->getSynthSolutions(sol_map))
-  {
-    InternalError() << "SmtEngine::checkSynthSolution(): No solution to check!";
-    return;
-  }
-  if (sol_map.empty())
-  {
-    InternalError() << "SmtEngine::checkSynthSolution(): Got empty solution!";
-    return;
-  }
-  Trace("check-synth-sol") << "Got solution map:\n";
-  // the set of synthesis conjectures in our assertions
-  std::unordered_set<Node, NodeHashFunction> conjs;
-  // For each of the above conjectures, the functions-to-synthesis and their
-  // solutions. This is used as a substitution below.
-  std::map<Node, std::vector<Node>> fvarMap;
-  std::map<Node, std::vector<Node>> fsolMap;
-  for (const std::pair<const Node, std::map<Node, Node>>& cmap : sol_map)
-  {
-    Trace("check-synth-sol") << "For conjecture " << cmap.first << ":\n";
-    conjs.insert(cmap.first);
-    std::vector<Node>& fvars = fvarMap[cmap.first];
-    std::vector<Node>& fsols = fsolMap[cmap.first];
-    for (const std::pair<const Node, Node>& pair : cmap.second)
-    {
-      Trace("check-synth-sol")
-          << "  " << pair.first << " --> " << pair.second << "\n";
-      fvars.push_back(pair.first);
-      fsols.push_back(pair.second);
-    }
-  }
-  Trace("check-synth-sol") << "Starting new SMT Engine\n";
-  /* Start new SMT engine to check solutions */
-  SmtEngine solChecker(d_exprManager, &d_options);
-  solChecker.setIsInternalSubsolver();
-  solChecker.setLogic(getLogicInfo());
-  solChecker.getOptions().set(options::checkSynthSol, false);
-  solChecker.getOptions().set(options::sygusRecFun, false);
-
-  Trace("check-synth-sol") << "Retrieving assertions\n";
-  // Build conjecture from original assertions
-  context::CDList<Node>* al = d_asserts->getAssertionList();
-  if (al == nullptr)
-  {
-    Trace("check-synth-sol") << "No assertions to check\n";
-    return;
-  }
-  // auxiliary assertions
-  std::vector<Node> auxAssertions;
-  // expand definitions cache
-  std::unordered_map<Node, Node, NodeHashFunction> cache;
-  for (const Node& assertion : *al)
-  {
-    Notice() << "SmtEngine::checkSynthSolution(): checking assertion "
-             << assertion << endl;
-    Trace("check-synth-sol") << "Retrieving assertion " << assertion << "\n";
-    // Apply any define-funs from the problem.
-    Node n = d_pp->expandDefinitions(assertion, cache);
-    Notice() << "SmtEngine::checkSynthSolution(): -- expands to " << n << endl;
-    Trace("check-synth-sol") << "Expanded assertion " << n << "\n";
-    if (conjs.find(n) == conjs.end())
-    {
-      Trace("check-synth-sol") << "It is an auxiliary assertion\n";
-      auxAssertions.push_back(n);
-    }
-    else
-    {
-      Trace("check-synth-sol") << "It is a synthesis conjecture\n";
-    }
-  }
-  // check all conjectures
-  for (const Node& conj : conjs)
-  {
-    // get the solution for this conjecture
-    std::vector<Node>& fvars = fvarMap[conj];
-    std::vector<Node>& fsols = fsolMap[conj];
-    // Apply solution map to conjecture body
-    Node conjBody;
-    /* Whether property is quantifier free */
-    if (conj[1].getKind() != kind::EXISTS)
-    {
-      conjBody = conj[1].substitute(
-          fvars.begin(), fvars.end(), fsols.begin(), fsols.end());
-    }
-    else
-    {
-      conjBody = conj[1][1].substitute(
-          fvars.begin(), fvars.end(), fsols.begin(), fsols.end());
-
-      /* Skolemize property */
-      std::vector<Node> vars, skos;
-      for (unsigned j = 0, size = conj[1][0].getNumChildren(); j < size; ++j)
-      {
-        vars.push_back(conj[1][0][j]);
-        std::stringstream ss;
-        ss << "sk_" << j;
-        skos.push_back(nm->mkSkolem(ss.str(), conj[1][0][j].getType()));
-        Trace("check-synth-sol") << "\tSkolemizing " << conj[1][0][j] << " to "
-                                 << skos.back() << "\n";
-      }
-      conjBody = conjBody.substitute(
-          vars.begin(), vars.end(), skos.begin(), skos.end());
-    }
-    Notice() << "SmtEngine::checkSynthSolution(): -- body substitutes to "
-             << conjBody << endl;
-    Trace("check-synth-sol") << "Substituted body of assertion to " << conjBody
-                             << "\n";
-    solChecker.assertFormula(conjBody);
-    // Assert all auxiliary assertions. This may include recursive function
-    // definitions that were added as assertions to the sygus problem.
-    for (const Node& a : auxAssertions)
-    {
-      solChecker.assertFormula(a);
-    }
-    Result r = solChecker.checkSat();
-    Notice() << "SmtEngine::checkSynthSolution(): result is " << r << endl;
-    Trace("check-synth-sol") << "Satsifiability check: " << r << "\n";
-    if (r.asSatisfiabilityResult().isUnknown())
-    {
-      InternalError() << "SmtEngine::checkSynthSolution(): could not check "
-                         "solution, result "
-                         "unknown.";
-    }
-    else if (r.asSatisfiabilityResult().isSat())
-    {
-      InternalError()
-          << "SmtEngine::checkSynthSolution(): produced solution leads to "
-             "satisfiable negated conjecture.";
-    }
-    solChecker.resetAssertions();
-  }
-}
-
 void SmtEngine::checkInterpol(Expr interpol,
                               const std::vector<Expr>& easserts,
                               const Node& conj)
 {
 }
 
-void SmtEngine::checkAbduct(Node a)
-{
-  Assert(a.getType().isBoolean());
-  // check it with the abduction solver
-  return d_abductSolver->checkAbduct(a);
-}
-
 // TODO(#1108): Simplify the error reporting of this method.
 UnsatCore SmtEngine::getUnsatCore() {
   Trace("smt") << "SMT getUnsatCore()" << endl;
@@ -2248,26 +1885,11 @@ void SmtEngine::printSynthSolution( std::ostream& out ) {
   te->printSynthSolution(out);
 }
 
-bool SmtEngine::getSynthSolutions(std::map<Expr, Expr>& sol_map)
+bool SmtEngine::getSynthSolutions(std::map<Node, Node>& solMap)
 {
   SmtScope smts(this);
   finishInit();
-  std::map<Node, std::map<Node, Node>> sol_mapn;
-  TheoryEngine* te = getTheoryEngine();
-  Assert(te != nullptr);
-  // fail if the theory engine does not have synthesis solutions
-  if (!te->getSynthSolutions(sol_mapn))
-  {
-    return false;
-  }
-  for (std::pair<const Node, std::map<Node, Node>>& cs : sol_mapn)
-  {
-    for (std::pair<const Node, Node>& s : cs.second)
-    {
-      sol_map[s.first.toExpr()] = s.second.toExpr();
-    }
-  }
-  return true;
+  return d_sygusSolver->getSynthSolutions(solMap);
 }
 
 Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull, bool strict)
@@ -2701,15 +2323,4 @@ ResourceManager* SmtEngine::getResourceManager()
 
 DumpManager* SmtEngine::getDumpManager() { return d_dumpm.get(); }
 
-void SmtEngine::setSygusConjectureStale()
-{
-  if (d_private->d_sygusConjectureStale)
-  {
-    // already stale
-    return;
-  }
-  d_private->d_sygusConjectureStale = true;
-  // TODO (project #7): if incremental, we should pop a context
-}
-
 }/* CVC4 namespace */
index 1c71e371e24764adce8d088bc83338813d834c95..a6688578df7afefb6798b99357ba2290cb4c35aa 100644 (file)
@@ -104,6 +104,7 @@ class OptionsManager;
 class Preprocessor;
 /** Subsolvers */
 class SmtSolver;
+class SygusSolver;
 class AbductionSolver;
 /**
  * Representation of a defined function.  We keep these around in
@@ -114,7 +115,6 @@ class AbductionSolver;
 class DefinedFunction;
 
 struct SmtEngineStatistics;
-class SmtEnginePrivate;
 class SmtScope;
 class ProcessAssertions;
 
@@ -146,7 +146,6 @@ class CVC4_PUBLIC SmtEngine
   friend class ::CVC4::api::Solver;
   // TODO (Issue #1096): Remove this friend relationship.
   friend class ::CVC4::preprocessing::PreprocessingPassContext;
-  friend class ::CVC4::smt::SmtEnginePrivate;
   friend class ::CVC4::smt::SmtEngineState;
   friend class ::CVC4::smt::SmtScope;
   friend class ::CVC4::smt::ProcessAssertions;
@@ -417,22 +416,18 @@ class CVC4_PUBLIC SmtEngine
   /*---------------------------- sygus commands  ---------------------------*/
 
   /**
-   * Add variable declaration.
+   * Add sygus variable declaration.
    *
    * Declared SyGuS variables may be used in SyGuS constraints, in which they
    * are assumed to be universally quantified.
-   */
-  void declareSygusVar(const std::string& id, Expr var, Type type);
-
-  /**
-   * Add a function variable declaration.
    *
-   * Is SyGuS semantics declared functions are treated in the same manner as
+   * In SyGuS semantics, declared functions are treated in the same manner as
    * declared variables, i.e. as universally quantified (function) variables
    * which can occur in the SyGuS constraints that compose the conjecture to
-   * which a function is being synthesized.
+   * which a function is being synthesized. Thus declared functions should use
+   * this method as well.
    */
-  void declareSygusFunctionVar(const std::string& id, Expr var, Type type);
+  void declareSygusVar(const std::string& id, Node var, TypeNode type);
 
   /**
    * Add a function-to-synthesize declaration.
@@ -450,13 +445,13 @@ class CVC4_PUBLIC SmtEngine
    * corresponding to this declaration, so that it can be properly printed.
    */
   void declareSynthFun(const std::string& id,
-                       Expr func,
-                       Type type,
+                       Node func,
+                       TypeNode type,
                        bool isInv,
-                       const std::vector<Expr>& vars);
+                       const std::vector<Node>& vars);
 
   /** Add a regular sygus constraint.*/
-  void assertSygusConstraint(const Node& constraint);
+  void assertSygusConstraint(Node constraint);
 
   /**
    * Add an invariant constraint.
@@ -473,10 +468,7 @@ class CVC4_PUBLIC SmtEngine
    * The regular and primed variables are retrieved from the declaration of the
    * invariant-to-synthesize.
    */
-  void assertSygusInvConstraint(const Expr& inv,
-                                const Expr& pre,
-                                const Expr& trans,
-                                const Expr& post);
+  void assertSygusInvConstraint(Node inv, Node pre, Node trans, Node post);
   /**
    * Assert a synthesis conjecture to the current context and call
    * check().  Returns sat, unsat, or unknown result.
@@ -576,18 +568,18 @@ class CVC4_PUBLIC SmtEngine
    * This method returns true if we are in a state immediately preceeded by
    * a successful call to checkSynth.
    *
-   * This method adds entries to sol_map that map functions-to-synthesize with
+   * This method adds entries to solMap that map functions-to-synthesize with
    * their solutions, for all active conjectures. This should be called
    * immediately after the solver answers unsat for sygus input.
    *
    * Specifically, given a sygus conjecture of the form
    *   exists x1...xn. forall y1...yn. P( x1...xn, y1...yn )
    * where x1...xn are second order bound variables, we map each xi to
-   * lambda term in sol_map such that
-   *    forall y1...yn. P( sol_map[x1]...sol_map[xn], y1...yn )
+   * lambda term in solMap such that
+   *    forall y1...yn. P( solMap[x1]...solMap[xn], y1...yn )
    * is a valid formula.
    */
-  bool getSynthSolutions(std::map<Expr, Expr>& sol_map);
+  bool getSynthSolutions(std::map<Node, Node>& solMap);
 
   /**
    * Do quantifier elimination.
@@ -951,16 +943,6 @@ class CVC4_PUBLIC SmtEngine
    */
   void checkModel(bool hardFailure = true);
 
-  /**
-   * Check that a solution to a synthesis conjecture is indeed a solution.
-   *
-   * The check is made by determining if the negation of the synthesis
-   * conjecture in which the functions-to-synthesize have been replaced by the
-   * synthesized solutions, which is a quantifier-free formula, is
-   * unsatisfiable. If not, then the found solutions are wrong.
-   */
-  void checkSynthSolution();
-
   /**
    * Check that a solution to an interpolation problem is indeed a solution.
    *
@@ -972,16 +954,6 @@ class CVC4_PUBLIC SmtEngine
                      const std::vector<Expr>& easserts,
                      const Node& conj);
 
-  /**
-   * Check that a solution to an abduction conjecture is indeed a solution.
-   *
-   * The check is made by determining that the assertions conjoined with the
-   * solution to the abduction problem (a) is SAT, and the assertions conjoined
-   * with the abduct and the goal is UNSAT. If these criteria are not met, an
-   * internal error is thrown.
-   */
-  void checkAbduct(Node a);
-
   /**
    * This is called by the destructor, just before destroying the
    * PropEngine, TheoryEngine, and DecisionEngine (in that order).  It
@@ -1129,6 +1101,9 @@ class CVC4_PUBLIC SmtEngine
   /** An index of our defined functions */
   DefinedFunctionMap* d_definedFunctions;
 
+  /** The solver for sygus queries */
+  std::unique_ptr<smt::SygusSolver> d_sygusSolver;
+
   /** The solver for abduction queries */
   std::unique_ptr<smt::AbductionSolver> d_abductSolver;
   /**
@@ -1164,11 +1139,6 @@ class CVC4_PUBLIC SmtEngine
    */
   std::map<std::string, Integer> d_commandVerbosity;
 
-  /**
-   * A private utility class to SmtEngine.
-   */
-  std::unique_ptr<smt::SmtEnginePrivate> d_private;
-
   std::unique_ptr<StatisticsRegistry> d_statisticsRegistry;
 
   std::unique_ptr<smt::SmtEngineStatistics> d_stats;
@@ -1195,23 +1165,6 @@ class CVC4_PUBLIC SmtEngine
    * or another SmtEngine is created.
    */
   std::unique_ptr<smt::SmtScope> d_scope;
-  /*---------------------------- sygus commands  ---------------------------*/
-
-  /**
-   * Set sygus conjecture is stale. The sygus conjecture is stale if either:
-   * (1) no sygus conjecture has been added as an assertion to this SMT engine,
-   * (2) there is a sygus conjecture that has been added as an assertion
-   * internally to this SMT engine, and there have been further calls such that
-   * the asserted conjecture is no longer up-to-date.
-   *
-   * This method should be called when new sygus constraints are asserted and
-   * when functions-to-synthesize are declared. This function pops a user
-   * context if we are in incremental mode and the sygus conjecture was
-   * previously not stale.
-   */
-  void setSygusConjectureStale();
-
-  /*------------------------- end of sygus commands ------------------------*/
 }; /* class SmtEngine */
 
 /* -------------------------------------------------------------------------- */
diff --git a/src/smt/sygus_solver.cpp b/src/smt/sygus_solver.cpp
new file mode 100644 (file)
index 0000000..0fc63d1
--- /dev/null
@@ -0,0 +1,402 @@
+/*********************                                                        */
+/*! \file sygus_solver.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The solver for sygus queries
+ **/
+
+#include "smt/sygus_solver.h"
+
+#include "expr/dtype.h"
+#include "options/quantifiers_options.h"
+#include "options/smt_options.h"
+#include "smt/preprocessor.h"
+#include "smt/smt_solver.h"
+#include "theory/smt_engine_subsolver.h"
+#include "theory/theory_engine.h"
+
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace smt {
+
+SygusSolver::SygusSolver(SmtSolver& sms,
+                         Preprocessor& pp,
+                         context::UserContext* u)
+    : d_smtSolver(sms), d_pp(pp), d_sygusConjectureStale(u, true)
+{
+}
+
+SygusSolver::~SygusSolver() {}
+
+void SygusSolver::declareSygusVar(const std::string& id,
+                                  Node var,
+                                  TypeNode type)
+{
+  Trace("smt") << "SygusSolver::declareSygusVar: " << id << " " << var << " "
+               << type << "\n";
+  Assert(var.getType() == type);
+  d_sygusVars.push_back(var);
+  // don't need to set that the conjecture is stale
+}
+
+void SygusSolver::declareSynthFun(const std::string& id,
+                                  Node fn,
+                                  TypeNode sygusType,
+                                  bool isInv,
+                                  const std::vector<Node>& vars)
+{
+  Trace("smt") << "SygusSolver::declareSynthFun: " << id << "\n";
+  NodeManager* nm = NodeManager::currentNM();
+  TheoryEngine* te = d_smtSolver.getTheoryEngine();
+  Assert(te != nullptr);
+  d_sygusFunSymbols.push_back(fn);
+  if (!vars.empty())
+  {
+    Node bvl = nm->mkNode(BOUND_VAR_LIST, vars);
+    std::vector<Node> attr_val_bvl;
+    attr_val_bvl.push_back(bvl);
+    te->setUserAttribute("sygus-synth-fun-var-list", fn, attr_val_bvl, "");
+  }
+  // whether sygus type encodes syntax restrictions
+  if (sygusType.isDatatype() && sygusType.getDType().isSygus())
+  {
+    Node sym = nm->mkBoundVar("sfproxy", sygusType);
+    std::vector<Node> attr_value;
+    attr_value.push_back(sym);
+    te->setUserAttribute("sygus-synth-grammar", fn, attr_value, "");
+  }
+
+  // sygus conjecture is now stale
+  setSygusConjectureStale();
+}
+
+void SygusSolver::assertSygusConstraint(Node constraint)
+{
+  Trace("smt") << "SygusSolver::assertSygusConstrant: " << constraint << "\n";
+  d_sygusConstraints.push_back(constraint);
+
+  // sygus conjecture is now stale
+  setSygusConjectureStale();
+}
+
+void SygusSolver::assertSygusInvConstraint(Node inv,
+                                           Node pre,
+                                           Node trans,
+                                           Node post)
+{
+  Trace("smt") << "SygusSolver::assertSygusInvConstrant: " << inv << " " << pre
+               << " " << trans << " " << post << "\n";
+  // build invariant constraint
+
+  // get variables (regular and their respective primed versions)
+  std::vector<Node> terms;
+  std::vector<Node> vars;
+  std::vector<Node> primed_vars;
+  terms.push_back(inv);
+  terms.push_back(pre);
+  terms.push_back(trans);
+  terms.push_back(post);
+  // variables are built based on the invariant type
+  NodeManager* nm = NodeManager::currentNM();
+  std::vector<TypeNode> argTypes = inv.getType().getArgTypes();
+  for (const TypeNode& tn : argTypes)
+  {
+    vars.push_back(nm->mkBoundVar(tn));
+    d_sygusVars.push_back(vars.back());
+    std::stringstream ss;
+    ss << vars.back() << "'";
+    primed_vars.push_back(nm->mkBoundVar(ss.str(), tn));
+    d_sygusVars.push_back(primed_vars.back());
+  }
+
+  // make relevant terms; 0 -> Inv, 1 -> Pre, 2 -> Trans, 3 -> Post
+  for (unsigned i = 0; i < 4; ++i)
+  {
+    Node op = terms[i];
+    Trace("smt-debug") << "Make inv-constraint term #" << i << " : " << op
+                       << " with type " << op.getType() << "...\n";
+    std::vector<Node> children;
+    children.push_back(op);
+    // transition relation applied over both variable lists
+    if (i == 2)
+    {
+      children.insert(children.end(), vars.begin(), vars.end());
+      children.insert(children.end(), primed_vars.begin(), primed_vars.end());
+    }
+    else
+    {
+      children.insert(children.end(), vars.begin(), vars.end());
+    }
+    terms[i] = nm->mkNode(APPLY_UF, children);
+    // make application of Inv on primed variables
+    if (i == 0)
+    {
+      children.clear();
+      children.push_back(op);
+      children.insert(children.end(), primed_vars.begin(), primed_vars.end());
+      terms.push_back(nm->mkNode(APPLY_UF, children));
+    }
+  }
+  // make constraints
+  std::vector<Node> conj;
+  conj.push_back(nm->mkNode(IMPLIES, terms[1], terms[0]));
+  Node term0_and_2 = nm->mkNode(AND, terms[0], terms[2]);
+  conj.push_back(nm->mkNode(IMPLIES, term0_and_2, terms[4]));
+  conj.push_back(nm->mkNode(IMPLIES, terms[0], terms[3]));
+  Node constraint = nm->mkNode(AND, conj);
+
+  d_sygusConstraints.push_back(constraint);
+
+  // sygus conjecture is now stale
+  setSygusConjectureStale();
+}
+
+Result SygusSolver::checkSynth(Assertions& as)
+{
+  if (options::incrementalSolving())
+  {
+    // TODO (project #7)
+    throw ModalException(
+        "Cannot make check-synth commands when incremental solving is enabled");
+  }
+  Trace("smt") << "SygusSolver::checkSynth" << std::endl;
+  std::vector<Node> query;
+  if (d_sygusConjectureStale)
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    // build synthesis conjecture from asserted constraints and declared
+    // variables/functions
+    Node sygusVar = nm->mkSkolem("sygus", nm->booleanType());
+    Node inst_attr = nm->mkNode(INST_ATTRIBUTE, sygusVar);
+    Node sygusAttr = nm->mkNode(INST_PATTERN_LIST, inst_attr);
+    std::vector<Node> bodyv;
+    Trace("smt") << "Sygus : Constructing sygus constraint...\n";
+    size_t nconstraints = d_sygusConstraints.size();
+    Node body = nconstraints == 0
+                    ? nm->mkConst(true)
+                    : (nconstraints == 1 ? d_sygusConstraints[0]
+                                         : nm->mkNode(AND, d_sygusConstraints));
+    body = body.notNode();
+    Trace("smt") << "...constructed sygus constraint " << body << std::endl;
+    if (!d_sygusVars.empty())
+    {
+      Node boundVars = nm->mkNode(BOUND_VAR_LIST, d_sygusVars);
+      body = nm->mkNode(EXISTS, boundVars, body);
+      Trace("smt") << "...constructed exists " << body << std::endl;
+    }
+    if (!d_sygusFunSymbols.empty())
+    {
+      Node boundVars = nm->mkNode(BOUND_VAR_LIST, d_sygusFunSymbols);
+      body = nm->mkNode(FORALL, boundVars, body, sygusAttr);
+    }
+    Trace("smt") << "...constructed forall " << body << std::endl;
+
+    // set attribute for synthesis conjecture
+    TheoryEngine* te = d_smtSolver.getTheoryEngine();
+    te->setUserAttribute("sygus", sygusVar, {}, "");
+
+    Trace("smt") << "Check synthesis conjecture: " << body << std::endl;
+    Dump("raw-benchmark") << CheckSynthCommand();
+
+    d_sygusConjectureStale = false;
+
+    // TODO (project #7): if incremental, we should push a context and assert
+    query.push_back(body);
+  }
+
+  Result r = d_smtSolver.checkSatisfiability(as, query, false, false);
+
+  // Check that synthesis solutions satisfy the conjecture
+  if (options::checkSynthSol()
+      && r.asSatisfiabilityResult().isSat() == Result::UNSAT)
+  {
+    checkSynthSolution(as);
+  }
+  return r;
+}
+
+bool SygusSolver::getSynthSolutions(std::map<Node, Node>& sol_map)
+{
+  Trace("smt") << "SygusSolver::getSynthSolutions" << std::endl;
+  std::map<Node, std::map<Node, Node>> sol_mapn;
+  // fail if the theory engine does not have synthesis solutions
+  TheoryEngine* te = d_smtSolver.getTheoryEngine();
+  Assert(te != nullptr);
+  if (!te->getSynthSolutions(sol_mapn))
+  {
+    return false;
+  }
+  for (std::pair<const Node, std::map<Node, Node>>& cs : sol_mapn)
+  {
+    for (std::pair<const Node, Node>& s : cs.second)
+    {
+      sol_map[s.first] = s.second;
+    }
+  }
+  return true;
+}
+
+void SygusSolver::checkSynthSolution(Assertions& as)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Notice() << "SygusSolver::checkSynthSolution(): checking synthesis solution"
+           << std::endl;
+  std::map<Node, std::map<Node, Node>> sol_map;
+  // Get solutions and build auxiliary vectors for substituting
+  TheoryEngine* te = d_smtSolver.getTheoryEngine();
+  if (!te->getSynthSolutions(sol_map))
+  {
+    InternalError()
+        << "SygusSolver::checkSynthSolution(): No solution to check!";
+    return;
+  }
+  if (sol_map.empty())
+  {
+    InternalError() << "SygusSolver::checkSynthSolution(): Got empty solution!";
+    return;
+  }
+  Trace("check-synth-sol") << "Got solution map:\n";
+  // the set of synthesis conjectures in our assertions
+  std::unordered_set<Node, NodeHashFunction> conjs;
+  // For each of the above conjectures, the functions-to-synthesis and their
+  // solutions. This is used as a substitution below.
+  std::map<Node, std::vector<Node>> fvarMap;
+  std::map<Node, std::vector<Node>> fsolMap;
+  for (const std::pair<const Node, std::map<Node, Node>>& cmap : sol_map)
+  {
+    Trace("check-synth-sol") << "For conjecture " << cmap.first << ":\n";
+    conjs.insert(cmap.first);
+    std::vector<Node>& fvars = fvarMap[cmap.first];
+    std::vector<Node>& fsols = fsolMap[cmap.first];
+    for (const std::pair<const Node, Node>& pair : cmap.second)
+    {
+      Trace("check-synth-sol")
+          << "  " << pair.first << " --> " << pair.second << "\n";
+      fvars.push_back(pair.first);
+      fsols.push_back(pair.second);
+    }
+  }
+  Trace("check-synth-sol") << "Starting new SMT Engine\n";
+
+  Trace("check-synth-sol") << "Retrieving assertions\n";
+  // Build conjecture from original assertions
+  context::CDList<Node>* alist = as.getAssertionList();
+  if (alist == nullptr)
+  {
+    Trace("check-synth-sol") << "No assertions to check\n";
+    return;
+  }
+  // auxiliary assertions
+  std::vector<Node> auxAssertions;
+  // expand definitions cache
+  std::unordered_map<Node, Node, NodeHashFunction> cache;
+  for (Node assertion : *alist)
+  {
+    Notice() << "SygusSolver::checkSynthSolution(): checking assertion "
+             << assertion << std::endl;
+    Trace("check-synth-sol") << "Retrieving assertion " << assertion << "\n";
+    // Apply any define-funs from the problem.
+    Node n = d_pp.expandDefinitions(assertion, cache);
+    Notice() << "SygusSolver::checkSynthSolution(): -- expands to " << n
+             << std::endl;
+    Trace("check-synth-sol") << "Expanded assertion " << n << "\n";
+    if (conjs.find(n) == conjs.end())
+    {
+      Trace("check-synth-sol") << "It is an auxiliary assertion\n";
+      auxAssertions.push_back(n);
+    }
+    else
+    {
+      Trace("check-synth-sol") << "It is a synthesis conjecture\n";
+    }
+  }
+  // check all conjectures
+  for (Node conj : conjs)
+  {
+    // Start new SMT engine to check solutions
+    std::unique_ptr<SmtEngine> solChecker;
+    initializeSubsolver(solChecker);
+    solChecker->getOptions().set(options::checkSynthSol, false);
+    solChecker->getOptions().set(options::sygusRecFun, false);
+    // get the solution for this conjecture
+    std::vector<Node>& fvars = fvarMap[conj];
+    std::vector<Node>& fsols = fsolMap[conj];
+    // Apply solution map to conjecture body
+    Node conjBody;
+    /* Whether property is quantifier free */
+    if (conj[1].getKind() != EXISTS)
+    {
+      conjBody = conj[1].substitute(
+          fvars.begin(), fvars.end(), fsols.begin(), fsols.end());
+    }
+    else
+    {
+      conjBody = conj[1][1].substitute(
+          fvars.begin(), fvars.end(), fsols.begin(), fsols.end());
+
+      /* Skolemize property */
+      std::vector<Node> vars, skos;
+      for (unsigned j = 0, size = conj[1][0].getNumChildren(); j < size; ++j)
+      {
+        vars.push_back(conj[1][0][j]);
+        std::stringstream ss;
+        ss << "sk_" << j;
+        skos.push_back(nm->mkSkolem(ss.str(), conj[1][0][j].getType()));
+        Trace("check-synth-sol") << "\tSkolemizing " << conj[1][0][j] << " to "
+                                 << skos.back() << "\n";
+      }
+      conjBody = conjBody.substitute(
+          vars.begin(), vars.end(), skos.begin(), skos.end());
+    }
+    Notice() << "SygusSolver::checkSynthSolution(): -- body substitutes to "
+             << conjBody << std::endl;
+    Trace("check-synth-sol")
+        << "Substituted body of assertion to " << conjBody << "\n";
+    solChecker->assertFormula(conjBody);
+    // Assert all auxiliary assertions. This may include recursive function
+    // definitions that were added as assertions to the sygus problem.
+    for (Node a : auxAssertions)
+    {
+      solChecker->assertFormula(a);
+    }
+    Result r = solChecker->checkSat();
+    Notice() << "SygusSolver::checkSynthSolution(): result is " << r
+             << std::endl;
+    Trace("check-synth-sol") << "Satsifiability check: " << r << "\n";
+    if (r.asSatisfiabilityResult().isUnknown())
+    {
+      InternalError() << "SygusSolver::checkSynthSolution(): could not check "
+                         "solution, result "
+                         "unknown.";
+    }
+    else if (r.asSatisfiabilityResult().isSat())
+    {
+      InternalError()
+          << "SygusSolver::checkSynthSolution(): produced solution leads to "
+             "satisfiable negated conjecture.";
+    }
+  }
+}
+
+void SygusSolver::setSygusConjectureStale()
+{
+  if (d_sygusConjectureStale)
+  {
+    // already stale
+    return;
+  }
+  d_sygusConjectureStale = true;
+  // TODO (project #7): if incremental, we should pop a context
+}
+
+}  // namespace smt
+}  // namespace CVC4
diff --git a/src/smt/sygus_solver.h b/src/smt/sygus_solver.h
new file mode 100644 (file)
index 0000000..468535d
--- /dev/null
@@ -0,0 +1,182 @@
+/*********************                                                        */
+/*! \file sygus_solver.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The solver for sygus queries
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__SMT__SYGUS_SOLVER_H
+#define CVC4__SMT__SYGUS_SOLVER_H
+
+#include "context/cdo.h"
+#include "expr/node.h"
+#include "expr/type_node.h"
+#include "smt/assertions.h"
+#include "util/result.h"
+
+namespace CVC4 {
+namespace smt {
+
+class Preprocessor;
+class SmtSolver;
+
+/**
+ * A solver for sygus queries.
+ *
+ * This class is responsible for responding to check-synth commands. It calls
+ * check satisfiability using an underlying SmtSolver object.
+ *
+ * It also maintains a reference to a preprocessor for implementing
+ * checkSynthSolution.
+ */
+class SygusSolver
+{
+ public:
+  SygusSolver(SmtSolver& sms, Preprocessor& pp, context::UserContext* u);
+  ~SygusSolver();
+
+  /**
+   * Add variable declaration.
+   *
+   * Declared SyGuS variables may be used in SyGuS constraints, in which they
+   * are assumed to be universally quantified.
+   *
+   * In SyGuS semantics, declared functions are treated in the same manner as
+   * declared variables, i.e. as universally quantified (function) variables
+   * which can occur in the SyGuS constraints that compose the conjecture to
+   * which a function is being synthesized. Thus declared functions should use
+   * this method as well.
+   */
+  void declareSygusVar(const std::string& id, Node var, TypeNode type);
+
+  /**
+   * Add a function-to-synthesize declaration.
+   *
+   * The given type may not correspond to the actual function type but to a
+   * datatype encoding the syntax restrictions for the
+   * function-to-synthesize. In this case this information is stored to be used
+   * during solving.
+   *
+   * vars contains the arguments of the function-to-synthesize. These variables
+   * are also stored to be used during solving.
+   *
+   * isInv determines whether the function-to-synthesize is actually an
+   * invariant. This information is necessary if we are dumping a command
+   * corresponding to this declaration, so that it can be properly printed.
+   */
+  void declareSynthFun(const std::string& id,
+                       Node func,
+                       TypeNode type,
+                       bool isInv,
+                       const std::vector<Node>& vars);
+
+  /** Add a regular sygus constraint.*/
+  void assertSygusConstraint(Node constraint);
+
+  /**
+   * Add an invariant constraint.
+   *
+   * Invariant constraints are not explicitly declared: they are given in terms
+   * of the invariant-to-synthesize, the pre condition, transition relation and
+   * post condition. The actual constraint is built based on the inputs of these
+   * place holder predicates :
+   *
+   * PRE(x) -> INV(x)
+   * INV() ^ TRANS(x, x') -> INV(x')
+   * INV(x) -> POST(x)
+   *
+   * The regular and primed variables are retrieved from the declaration of the
+   * invariant-to-synthesize.
+   */
+  void assertSygusInvConstraint(Node inv, Node pre, Node trans, Node post);
+  /**
+   * Assert a synthesis conjecture to the current context and call
+   * check().  Returns sat, unsat, or unknown result.
+   *
+   * The actual synthesis conjecture is built based on the previously
+   * communicated information to this module (universal variables, defined
+   * functions, functions-to-synthesize, and which constraints compose it). The
+   * built conjecture is a higher-order formula of the form
+   *
+   * exists f1...fn . forall v1...vm . F
+   *
+   * in which f1...fn are the functions-to-synthesize, v1...vm are the declared
+   * universal variables and F is the set of declared constraints.
+   */
+  Result checkSynth(Assertions& as);
+  /**
+   * Get synth solution.
+   *
+   * This method returns true if we are in a state immediately preceeded by
+   * a successful call to checkSynth.
+   *
+   * This method adds entries to sol_map that map functions-to-synthesize with
+   * their solutions, for all active conjectures. This should be called
+   * immediately after the solver answers unsat for sygus input.
+   *
+   * Specifically, given a sygus conjecture of the form
+   *   exists x1...xn. forall y1...yn. P( x1...xn, y1...yn )
+   * where x1...xn are second order bound variables, we map each xi to
+   * lambda term in sol_map such that
+   *    forall y1...yn. P( sol_map[x1]...sol_map[xn], y1...yn )
+   * is a valid formula.
+   */
+  bool getSynthSolutions(std::map<Node, Node>& sol_map);
+
+ private:
+  /**
+   * Check that a solution to a synthesis conjecture is indeed a solution.
+   *
+   * The check is made by determining if the negation of the synthesis
+   * conjecture in which the functions-to-synthesize have been replaced by the
+   * synthesized solutions, which is a quantifier-free formula, is
+   * unsatisfiable. If not, then the found solutions are wrong.
+   */
+  void checkSynthSolution(Assertions& as);
+  /**
+   * Set sygus conjecture is stale. The sygus conjecture is stale if either:
+   * (1) no sygus conjecture has been added as an assertion to this SMT engine,
+   * (2) there is a sygus conjecture that has been added as an assertion
+   * internally to this SMT engine, and there have been further calls such that
+   * the asserted conjecture is no longer up-to-date.
+   *
+   * This method should be called when new sygus constraints are asserted and
+   * when functions-to-synthesize are declared. This function pops a user
+   * context if we are in incremental mode and the sygus conjecture was
+   * previously not stale.
+   */
+  void setSygusConjectureStale();
+  /** The SMT solver, which is used during checkSynth. */
+  SmtSolver& d_smtSolver;
+  /** The preprocessor, used for checkSynthSolution. */
+  Preprocessor& d_pp;
+  /**
+   * sygus variables declared (from "declare-var" and "declare-fun" commands)
+   *
+   * The SyGuS semantics for declared variables is that they are implicitly
+   * universally quantified in the constraints.
+   */
+  std::vector<Node> d_sygusVars;
+  /** sygus constraints */
+  std::vector<Node> d_sygusConstraints;
+  /** functions-to-synthesize */
+  std::vector<Node> d_sygusFunSymbols;
+  /**
+   * Whether we need to reconstruct the sygus conjecture.
+   */
+  context::CDO<bool> d_sygusConjectureStale;
+};
+
+}  // namespace smt
+}  // namespace CVC4
+
+#endif /* CVC4__SMT__SYGUS_SOLVER_H */
index 0d08140d307a3b772584cb22a454aecb521fd3e3..0ecd888e07efa7faf647e23bd64bb8434e34ed54 100644 (file)
@@ -274,10 +274,10 @@ void SygusInterpol::mkSygusConjecture(Node itp,
 bool SygusInterpol::findInterpol(Expr& interpol, Node itp)
 {
   // get the synthesis solution
-  std::map<Expr, Expr> sols;
+  std::map<Node, Node> sols;
   d_subSolver->getSynthSolutions(sols);
   Assert(sols.size() == 1);
-  std::map<Expr, Expr>::iterator its = sols.find(itp.toExpr());
+  std::map<Node, Node>::iterator its = sols.find(itp);
   if (its == sols.end())
   {
     Trace("sygus-interpol")
@@ -288,7 +288,7 @@ bool SygusInterpol::findInterpol(Expr& interpol, Node itp)
   }
   Trace("sygus-interpol") << "SmtEngine::getInterpol: solution is "
                           << its->second << std::endl;
-  Node interpoln = Node::fromExpr(its->second);
+  Node interpoln = its->second;
   // replace back the created variables to original symbols.
   Node interpoln_reduced;
   if (interpoln.getKind() == kind::LAMBDA)
@@ -336,18 +336,17 @@ bool SygusInterpol::SolveInterpolation(const std::string& name,
   createVariables(itpGType.isNull());
   for (Node var : d_vars)
   {
-    d_subSolver->declareSygusVar(name, var.toExpr(), var.getType().toType());
+    d_subSolver->declareSygusVar(name, var, var.getType());
   }
-  std::vector<Expr> vars_empty;
+  std::vector<Node> vars_empty;
   TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj);
   Node itp = mkPredicate(name);
-  d_subSolver->declareSynthFun(
-      name, itp.toExpr(), grammarType.toType(), false, vars_empty);
+  d_subSolver->declareSynthFun(name, itp, grammarType, false, vars_empty);
   mkSygusConjecture(itp, axioms, conj);
   Trace("sygus-interpol") << "SmtEngine::getInterpol: made conjecture : "
                           << d_sygusConj << ", solving for "
-                          << d_sygusConj[0][0].toExpr() << std::endl;
-  d_subSolver->assertSygusConstraint(d_sygusConj.toExpr());
+                          << d_sygusConj[0][0] << std::endl;
+  d_subSolver->assertSygusConstraint(d_sygusConj);
 
   Trace("sygus-interpol") << "  SmtEngine::getInterpol check sat..."
                           << std::endl;