Refactor DeclareSygusVarCommand and SynthFunCommand to use the API. (#5334)
authorAbdalrhman Mohamed <32971963+abdoo8080@users.noreply.github.com>
Tue, 27 Oct 2020 18:19:11 +0000 (13:19 -0500)
committerGitHub <noreply@github.com>
Tue, 27 Oct 2020 18:19:11 +0000 (13:19 -0500)
This PR is part of migrating commands. DeclareSygusVarCommand and SynthFunCommand now call public API function instead of their corresponding SmtEngine functions. Those two commands don't fully initialize the solver anymore. Some operations in SygusInterpol::solveInterpolation, which creates an interpolation sub-solver, depend on the solver being fully initialized and were affected by this change. Those operations are now executed by the main solver instead of the sub-solver, which is not fully initialized by the time they are needed.

src/api/cvc4cpp.cpp
src/parser/smt2/Smt2.g
src/parser/smt2/smt2.cpp
src/parser/smt2/smt2.h
src/smt/command.cpp
src/smt/smt_engine.cpp
src/theory/quantifiers/sygus/sygus_interpol.cpp
test/unit/api/solver_black.h

index bda88c53995a11888699fe40692f9c9d24691dd5..3cfeaf6cd0c1b8935cab02293413f826b865699c 100644 (file)
@@ -5591,9 +5591,6 @@ Term Solver::synthFunHelper(const std::string& symbol,
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
   CVC4_API_ARG_CHECK_NOT_NULL(sort);
 
-  CVC4_API_ARG_CHECK_EXPECTED(sort.d_type->isFirstClass(), sort)
-      << "first-class codomain sort for function";
-
   std::vector<Type> varTypes;
   for (size_t i = 0, n = boundVars.size(); i < n; ++i)
   {
index 232723fc0ea49d04308f4f8836287b5b72563c91..7c1c5dc3e909d8e75a1cea303cce14bfaf8c1abd 100644 (file)
@@ -537,12 +537,12 @@ command [std::unique_ptr<CVC4::Command>* cmd]
 
 sygusCommand returns [std::unique_ptr<CVC4::Command> cmd]
 @declarations {
-  CVC4::api::Term expr, expr2;
+  CVC4::api::Term expr, expr2, fun;
   CVC4::api::Sort t, range;
   std::vector<std::string> names;
   std::vector<std::pair<std::string, CVC4::api::Sort> > sortedVarNames;
-  std::unique_ptr<Smt2::SynthFunFactory> synthFunFactory;
-  std::string name, fun;
+  std::vector<CVC4::api::Term> sygusVars;
+  std::string name;
   bool isInv;
   CVC4::api::Grammar* grammar = nullptr;
 }
@@ -552,7 +552,8 @@ sygusCommand returns [std::unique_ptr<CVC4::Command> cmd]
     { PARSER_STATE->checkUserSymbol(name); }
     sortSymbol[t,CHECK_DECLARED]
     {
-      api::Term var = PARSER_STATE->bindBoundVar(name, t);
+      api::Term var = SOLVER->mkSygusVar(t, name);
+      PARSER_STATE->defineVar(name, var);
       cmd.reset(new DeclareSygusVarCommand(name, var, t));
     }
   | /* synth-fun */
@@ -560,22 +561,36 @@ sygusCommand returns [std::unique_ptr<CVC4::Command> cmd]
       | SYNTH_INV_TOK { isInv = true; range = SOLVER->getBooleanSort(); }
     )
     { PARSER_STATE->checkThatLogicIsSet(); }
-    symbol[fun,CHECK_UNDECLARED,SYM_VARIABLE]
+    symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
     LPAREN_TOK sortedVarList[sortedVarNames] RPAREN_TOK
     ( sortSymbol[range,CHECK_DECLARED] )?
     {
-      synthFunFactory.reset(new Smt2::SynthFunFactory(
-          PARSER_STATE, fun, isInv, range, sortedVarNames));
+      PARSER_STATE->pushScope(true);
+      sygusVars = PARSER_STATE->bindBoundVars(sortedVarNames);
     }
     (
       // optionally, read the sygus grammar
       //
       // `grammar` specifies the required grammar for the function to
       // synthesize, expressed as a type
-      sygusGrammar[grammar, synthFunFactory->getSygusVars(), fun]
+      sygusGrammar[grammar, sygusVars, name]
     )?
     {
-      cmd = synthFunFactory->mkCommand(grammar);
+      Debug("parser-sygus") << "Define synth fun : " << name << std::endl;
+
+      fun = isInv ? (grammar == nullptr
+                         ? SOLVER->synthInv(name, sygusVars)
+                         : SOLVER->synthInv(name, sygusVars, *grammar))
+                  : (grammar == nullptr
+                         ? SOLVER->synthFun(name, sygusVars, range)
+                         : SOLVER->synthFun(name, sygusVars, range, *grammar));
+
+      Debug("parser-sygus") << "...read synth fun " << name << std::endl;
+      PARSER_STATE->popScope();
+      // we do not allow overloading for synth fun
+      PARSER_STATE->defineVar(name, fun);
+      cmd = std::unique_ptr<Command>(
+          new SynthFunCommand(name, fun, sygusVars, range, isInv, grammar));
     }
   | /* constraint */
     CONSTRAINT_TOK {
index 81a4bd4a6fcbfcaf1c4a4c2fa3b7296b55d0d2d3..629164593d2ee042104bca2e5c5ec020ace87b08 100644 (file)
@@ -482,50 +482,6 @@ void Smt2::resetAssertions() {
   pushScope(true);
 }
 
-Smt2::SynthFunFactory::SynthFunFactory(
-    Smt2* smt2,
-    const std::string& id,
-    bool isInv,
-    api::Sort range,
-    std::vector<std::pair<std::string, api::Sort>>& sortedVarNames)
-    : d_smt2(smt2), d_id(id), d_sort(range), d_isInv(isInv)
-{
-  if (range.isNull())
-  {
-    smt2->parseError("Must supply return type for synth-fun.");
-  }
-  if (range.isFunction())
-  {
-    smt2->parseError("Cannot use synth-fun with function return type.");
-  }
-
-  std::vector<api::Sort> varSorts;
-  for (const std::pair<std::string, api::Sort>& p : sortedVarNames)
-  {
-    varSorts.push_back(p.second);
-  }
-
-  api::Sort funSort = varSorts.empty()
-                          ? range
-                          : d_smt2->d_solver->mkFunctionSort(varSorts, range);
-
-  // we do not allow overloading for synth fun
-  d_fun = d_smt2->bindBoundVar(id, funSort);
-
-  Debug("parser-sygus") << "Define synth fun : " << id << std::endl;
-
-  d_smt2->pushScope(true);
-  d_sygusVars = d_smt2->bindBoundVars(sortedVarNames);
-}
-
-std::unique_ptr<Command> Smt2::SynthFunFactory::mkCommand(api::Grammar* grammar)
-{
-  Debug("parser-sygus") << "...read synth fun " << d_id << std::endl;
-  d_smt2->popScope();
-  return std::unique_ptr<Command>(
-      new SynthFunCommand(d_id, d_fun, d_sygusVars, d_sort, d_isInv, grammar));
-}
-
 std::unique_ptr<Command> Smt2::invConstraint(
     const std::vector<std::string>& names)
 {
index 5fcf496374d389e976605a7bdcf54a43cdb7cfb8..1aa0ebac7e255cb164cde34241fd78df2a817008 100644 (file)
@@ -194,49 +194,6 @@ class Smt2 : public Parser
 
   void resetAssertions();
 
-  /**
-   * Class for creating instances of `SynthFunCommand`s. Creating an instance
-   * of this class pushes the scope, destroying it pops the scope.
-   */
-  class SynthFunFactory
-  {
-   public:
-    /**
-     * Creates an instance of `SynthFunFactory`.
-     *
-     * @param smt2 Pointer to the parser state
-     * @param id Name of the function to synthesize
-     * @param isInv True if the goal is to synthesize an invariant, false
-     * otherwise
-     * @param range The return type of the function-to-synthesize
-     * @param sortedVarNames The parameters of the function-to-synthesize
-     */
-    SynthFunFactory(
-        Smt2* smt2,
-        const std::string& id,
-        bool isInv,
-        api::Sort range,
-        std::vector<std::pair<std::string, api::Sort>>& sortedVarNames);
-
-    const std::vector<api::Term>& getSygusVars() const { return d_sygusVars; }
-
-    /**
-     * Create an instance of `SynthFunCommand`.
-     *
-     * @param grammar Optional grammar associated with the synth-fun command
-     * @return The instance of `SynthFunCommand`
-     */
-    std::unique_ptr<Command> mkCommand(api::Grammar* grammar);
-
-   private:
-    Smt2* d_smt2;
-    std::string d_id;
-    api::Term d_fun;
-    api::Sort d_sort;
-    bool d_isInv;
-    std::vector<api::Term> d_sygusVars;
-  };
-
   /**
    * Creates a command that adds an invariant constraint.
    *
index 9c45c0b196fbd40ba86f9be79a849d5bbad9bf29..eb03edf4f9ffeb256c5ba3fb987249689de96775 100644 (file)
@@ -581,16 +581,7 @@ api::Sort DeclareSygusVarCommand::getSort() const { return d_sort; }
 
 void DeclareSygusVarCommand::invoke(api::Solver* solver)
 {
-  try
-  {
-    solver->getSmtEngine()->declareSygusVar(
-        d_symbol, d_var.getNode(), TypeNode::fromType(d_sort.getType()));
-    d_commandStatus = CommandSuccess::instance();
-  }
-  catch (exception& e)
-  {
-    d_commandStatus = new CommandFailure(e.what());
-  }
+  d_commandStatus = CommandSuccess::instance();
 }
 
 Command* DeclareSygusVarCommand::clone() const
@@ -646,27 +637,7 @@ const api::Grammar* SynthFunCommand::getGrammar() const { return d_grammar; }
 
 void SynthFunCommand::invoke(api::Solver* solver)
 {
-  try
-  {
-    std::vector<Node> vns;
-    for (const api::Term& t : d_vars)
-    {
-      vns.push_back(Node::fromExpr(t.getExpr()));
-    }
-    solver->getSmtEngine()->declareSynthFun(
-        d_symbol,
-        Node::fromExpr(d_fun.getExpr()),
-        TypeNode::fromType(d_grammar == nullptr
-                               ? d_sort.getType()
-                               : d_grammar->resolve().getType()),
-        d_isInv,
-        vns);
-    d_commandStatus = CommandSuccess::instance();
-  }
-  catch (exception& e)
-  {
-    d_commandStatus = new CommandFailure(e.what());
-  }
+  d_commandStatus = CommandSuccess::instance();
 }
 
 Command* SynthFunCommand::clone() const
index f345bee2ead800dbdfce50f408b5c418e3acb280..d0906ce98ec87dc2f79b505eedd19f50c14a9413 100644 (file)
@@ -1065,7 +1065,6 @@ Result SmtEngine::assertFormula(const Node& formula, bool inUnsatCore)
 void SmtEngine::declareSygusVar(const std::string& id, Node var, TypeNode type)
 {
   SmtScope smts(this);
-  finishInit();
   d_sygusSolver->declareSygusVar(id, var, type);
   if (Dump.isOn("raw-benchmark"))
   {
@@ -1082,7 +1081,6 @@ void SmtEngine::declareSynthFun(const std::string& id,
                                 const std::vector<Node>& vars)
 {
   SmtScope smts(this);
-  finishInit();
   d_state->doPendingPops();
   d_sygusSolver->declareSynthFun(id, func, sygusType, isInv, vars);
 
index e4e7a02c749d3f535b2cb099c2709718767084d4..d5ab0e51f0f7e3bae622d7c3f2f31063a217876b 100644 (file)
@@ -319,6 +319,18 @@ bool SygusInterpol::solveInterpolation(const std::string& name,
                                        const TypeNode& itpGType,
                                        Node& interpol)
 {
+  // Some instructions in setSynthGrammar and mkSygusConjecture need a fully
+  // initialized solver to work properly. Notice, however, that the sub-solver
+  // created below is not fully initialized by the time those two methods are
+  // needed. Therefore, we call them while the current parent solver is in scope
+  // (i.e., before creating the sub-solver).
+  collectSymbols(axioms, conj);
+  createVariables(itpGType.isNull());
+  TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj);
+
+  Node itp = mkPredicate(name);
+  mkSygusConjecture(itp, axioms, conj);
+
   std::unique_ptr<SmtEngine> subSolver;
   initializeSubsolver(subSolver);
   // get the logic
@@ -327,17 +339,12 @@ bool SygusInterpol::solveInterpolation(const std::string& name,
   l.enableSygus();
   subSolver->setLogic(l);
 
-  collectSymbols(axioms, conj);
-  createVariables(itpGType.isNull());
   for (Node var : d_vars)
   {
     subSolver->declareSygusVar(name, var, var.getType());
   }
   std::vector<Node> vars_empty;
-  TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj);
-  Node itp = mkPredicate(name);
   subSolver->declareSynthFun(name, itp, grammarType, false, vars_empty);
-  mkSygusConjecture(itp, axioms, conj);
   Trace("sygus-interpol") << "SmtEngine::getInterpol: made conjecture : "
                           << d_sygusConj << ", solving for "
                           << d_sygusConj[0][0] << std::endl;
index aa4289ef30506ab48012d6e0646066b7f99142a8..8b8c6dd58f113d57bd033958210bdd96b6042b94 100644 (file)
@@ -2268,7 +2268,6 @@ void SolverBlack::testSynthFun()
   Sort null = d_solver->getNullSort();
   Sort boolean = d_solver->getBooleanSort();
   Sort integer = d_solver->getIntegerSort();
-  Sort boolToBool = d_solver->mkFunctionSort(boolean, boolean);
 
   Term nullTerm;
   Term x = d_solver->mkVar(boolean);
@@ -2289,7 +2288,6 @@ void SolverBlack::testSynthFun()
   TS_ASSERT_THROWS(d_solver->synthFun("f3", {nullTerm}, boolean),
                    CVC4ApiException&);
   TS_ASSERT_THROWS(d_solver->synthFun("f4", {}, null), CVC4ApiException&);
-  TS_ASSERT_THROWS(d_solver->synthFun("f5", {}, boolToBool), CVC4ApiException&);
   TS_ASSERT_THROWS(d_solver->synthFun("f6", {x}, boolean, g2),
                    CVC4ApiException&);
   Solver slv;