Refactoring the single invocation solver (#5706)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 20 Jan 2021 23:32:06 +0000 (17:32 -0600)
committerGitHub <noreply@github.com>
Wed, 20 Jan 2021 23:32:06 +0000 (17:32 -0600)
This does an intermediate refactoring of the single invocation solver to make a few things clearer and to add preliminary support for functions that have been marked as solved by external techniques.

This is in preparation for generalizing the CAV 2015 single invocation techniques.

src/theory/quantifiers/sygus/ce_guided_single_inv.cpp
src/theory/quantifiers/sygus/ce_guided_single_inv.h
src/theory/quantifiers/sygus/ce_guided_single_inv_sol.cpp
src/theory/quantifiers/sygus/ce_guided_single_inv_sol.h

index e9e15ef3bf5aa1ca16e8aa6137801e1a3df98cb4..10d6227e36987876fea91d5b4434713460554e8c 100644 (file)
@@ -22,6 +22,7 @@
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/quantifiers/quantifiers_rewriter.h"
 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
+#include "theory/quantifiers/sygus/sygus_utils.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 #include "theory/quantifiers/term_enumeration.h"
 #include "theory/quantifiers/term_util.h"
@@ -57,10 +58,23 @@ void CegSingleInv::initialize(Node q)
   d_quant = q;
   d_simp_quant = q;
   Trace("sygus-si") << "CegSingleInv::initialize : " << q << std::endl;
+
+  // decompose the conjecture
+  SygusUtils::decomposeSygusConjecture(d_quant, d_funs, d_unsolvedf, d_solvedf);
+
+  Trace("sygus-si") << "functions: " << d_funs << std::endl;
+  Trace("sygus-si") << " unsolved: " << d_unsolvedf << std::endl;
+  Trace("sygus-si") << "   solved: " << d_solvedf << std::endl;
+
   // infer single invocation-ness
 
   // get the variables
-  std::vector<Node> progs(q[0].begin(), q[0].end());
+  std::map<Node, std::vector<Node> > progVars;
+  for (const Node& sf : q[0])
+  {
+    // get its argument list
+    SygusUtils::getSygusArgumentListForSynthFun(sf, progVars[sf]);
+  }
   // compute single invocation partition
   Node qq;
   if (q[1].getKind() == NOT && q[1][0].getKind() == FORALL)
@@ -72,7 +86,7 @@ void CegSingleInv::initialize(Node q)
     qq = TermUtil::simpleNegate(q[1]);
   }
   // process the single invocation-ness of the property
-  if (!d_sip->init(progs, qq))
+  if (!d_sip->init(d_unsolvedf, qq))
   {
     Trace("sygus-si") << "...not single invocation (type mismatch)"
                       << std::endl;
@@ -87,7 +101,7 @@ void CegSingleInv::initialize(Node q)
   d_sip->getFunctions(funcs);
   for (unsigned j = 0, size = funcs.size(); j < size; j++)
   {
-    Assert(std::find(progs.begin(), progs.end(), funcs[j]) != progs.end());
+    Assert(std::find(d_funs.begin(), d_funs.end(), funcs[j]) != d_funs.end());
     d_prog_to_sol_index[funcs[j]] = j;
   }
 
@@ -158,8 +172,12 @@ void CegSingleInv::finishInit(bool syntaxRestricted)
   CegHandledStatus status = CEG_HANDLED;
   if (d_single_inv.getKind() == FORALL)
   {
-    // if the conjecture is not trivially solvable
-    if (!solveTrivial(d_single_inv))
+    // if the conjecture is trivially solvable, set the solution
+    if (solveTrivial(d_single_inv))
+    {
+      setSolution();
+    }
+    else
     {
       status = CegInstantiator::isCbqiQuant(d_single_inv);
     }
@@ -225,8 +243,9 @@ bool CegSingleInv::solve()
                     << std::endl;
   d_inst.clear();
   d_instConds.clear();
-  for (const Node& q : qs)
+  if (!qs.empty())
   {
+    Node q = qs[0];
     Assert(q.getKind() == FORALL);
     siSmt->getInstantiationTermVectors(q, d_inst);
     Trace("sygus-si") << "#instantiations of " << q << "=" << d_inst.size()
@@ -255,7 +274,8 @@ bool CegSingleInv::solve()
       Trace("sygus-si") << "  Instantiation Lemma: " << ilem << std::endl;
     }
   }
-  d_isSolved = true;
+  // set the solution
+  setSolution();
   return true;
 }
 
@@ -272,37 +292,71 @@ struct sortSiInstanceIndices {
   }
 };
 
-Node CegSingleInv::getSolution(unsigned sol_index,
+Node CegSingleInv::getSolution(size_t sol_index,
                                TypeNode stn,
                                int& reconstructed,
                                bool rconsSygus)
 {
-  Assert(d_sol != NULL);
+  Assert(sol_index < d_quant[0].getNumChildren());
+  Node f = d_quant[0][sol_index];
+  Trace("csi-sol") << "CegSingleInv::getSolution " << f << std::endl;
+  // maybe it is in the solved map already?
+  if (d_solvedf.contains(f))
+  {
+    // notice that we ignore d_solutions for solved functions
+    Trace("csi-sol") << "...return solution from annotation" << std::endl;
+    return d_solvedf.apply(f);
+  }
+  Trace("csi-sol") << "...get solution from vector" << std::endl;
+
+  Node s = d_solutions[sol_index];
+  Node sol = s.getKind() == LAMBDA ? s[1] : s;
+  // must substitute to be proper variables
   const DType& dt = stn.getDType();
   Node varList = dt.getSygusVarList();
-  Node prog = d_quant[0][sol_index];
+  d_sol->d_varList.clear();
+  Assert(d_single_inv_arg_sk.size() == varList.getNumChildren());
   std::vector< Node > vars;
+  for (size_t i = 0, nvars = d_single_inv_arg_sk.size(); i < nvars; i++)
+  {
+    Trace("csi-sol") << d_single_inv_arg_sk[i] << " ";
+    vars.push_back(d_single_inv_arg_sk[i]);
+    d_sol->d_varList.push_back(varList[i]);
+  }
+  Trace("csi-sol") << std::endl;
+  Assert(vars.size() == d_sol->d_varList.size());
+  sol = sol.substitute(vars.begin(),
+                       vars.end(),
+                       d_sol->d_varList.begin(),
+                       d_sol->d_varList.end());
+  sol = reconstructToSyntax(sol, stn, reconstructed, rconsSygus);
+  return s.getKind() == LAMBDA
+             ? NodeManager::currentNM()->mkNode(LAMBDA, s[0], sol)
+             : sol;
+}
+
+Node CegSingleInv::getSolutionFromInst(size_t index)
+{
+  Assert(d_sol != NULL);
+  Node prog = d_quant[0][index];
   Node s;
   // If it is unconstrained: either the variable does not appear in the
   // conjecture or the conjecture can be solved without a single instantiation.
   if (d_prog_to_sol_index.find(prog) == d_prog_to_sol_index.end()
       || d_inst.empty())
   {
+    TypeNode ptn = prog.getType();
+    if (ptn.isFunction())
+    {
+      ptn = ptn.getRangeType();
+    }
     Trace("csi-sol") << "Get solution for (unconstrained) " << prog << std::endl;
-    s = d_qe->getTermEnumeration()->getEnumerateTerm(dt.getSygusType(), 0);
+    s = d_qe->getTermEnumeration()->getEnumerateTerm(ptn, 0);
   }
   else
   {
     Trace("csi-sol") << "Get solution for " << prog << ", with skolems : ";
-    sol_index = d_prog_to_sol_index[prog];
-    d_sol->d_varList.clear();
-    Assert(d_single_inv_arg_sk.size() == varList.getNumChildren());
-    for( unsigned i=0; i<d_single_inv_arg_sk.size(); i++ ){
-      Trace("csi-sol") << d_single_inv_arg_sk[i] << " ";
-      vars.push_back( d_single_inv_arg_sk[i] );
-      d_sol->d_varList.push_back( varList[i] );
-    }
-    Trace("csi-sol") << std::endl;
+    size_t sol_index = d_prog_to_sol_index[prog];
 
     //construct the solution
     Trace("csi-sol") << "Sort solution return values " << sol_index << std::endl;
@@ -340,16 +394,37 @@ Node CegSingleInv::getSolution(unsigned sol_index,
       cond = TermUtil::simpleNegate(cond);
       s = nm->mkNode(ITE, cond, d_inst[uindex][sol_index], s);
     }
-    Assert(vars.size() == d_sol->d_varList.size());
-    s = s.substitute( vars.begin(), vars.end(), d_sol->d_varList.begin(), d_sol->d_varList.end() );
   }
-  d_orig_solution = s;
-
   //simplify the solution using the extended rewriter
-  Trace("csi-sol") << "Solution (pre-simplification): " << d_orig_solution << std::endl;
+  Trace("csi-sol") << "Solution (pre-simplification): " << s << std::endl;
   s = d_qe->getTermDatabaseSygus()->getExtRewriter()->extendedRewrite(s);
   Trace("csi-sol") << "Solution (post-simplification): " << s << std::endl;
-  return reconstructToSyntax( s, stn, reconstructed, rconsSygus );
+  // wrap into lambda, as needed
+  return SygusUtils::wrapSolutionForSynthFun(prog, s);
+}
+
+void CegSingleInv::setSolution()
+{
+  // construct the solutions based on the instantiations
+  d_solutions.clear();
+  d_rcSolutions.clear();
+  Subs finalSol;
+  for (size_t i = 0, nvars = d_quant[0].getNumChildren(); i < nvars; i++)
+  {
+    // Note this is a dummy solution for solved functions, which are given
+    // solutions in the annotation but do not appear in the conjecture.
+    Node sol = getSolutionFromInst(i);
+    d_solutions.push_back(sol);
+    // haven't reconstructed to syntax yet
+    d_rcSolutions.push_back(Node::null());
+    finalSol.add(d_quant[0][i], sol);
+  }
+  d_isSolved = true;
+  if (!d_solvedf.empty())
+  {
+    // replace the final solution into the solved functions
+    finalSol.applyToRange(d_solvedf, true);
+  }
 }
 
 Node CegSingleInv::reconstructToSyntax(Node s,
@@ -357,7 +432,8 @@ Node CegSingleInv::reconstructToSyntax(Node s,
                                        int& reconstructed,
                                        bool rconsSygus)
 {
-  d_solution = s;
+  // extract the lambda body
+  Node sol = s;
   const DType& dt = stn.getDType();
 
   //reconstruct the solution into sygus if necessary
@@ -378,67 +454,27 @@ Node CegSingleInv::reconstructToSyntax(Node s,
     {
       enumLimit = options::cegqiSingleInvReconstructLimit();
     }
-    d_sygus_solution =
-        d_sol->reconstructSolution(s, stn, reconstructed, enumLimit);
+    sol = d_sol->reconstructSolution(s, stn, reconstructed, enumLimit);
     if( reconstructed==1 ){
-      Trace("csi-sol") << "Solution (post-reconstruction into Sygus): " << d_sygus_solution << std::endl;
+      Trace("csi-sol") << "Solution (post-reconstruction into Sygus): " << sol
+                       << std::endl;
     }
   }else{
     Trace("csi-sol") << "Post-process solution..." << std::endl;
-    Node prev = d_solution;
-    d_solution =
-        d_qe->getTermDatabaseSygus()->getExtRewriter()->extendedRewrite(
-            d_solution);
-    if( prev!=d_solution ){
-      Trace("csi-sol") << "Solution (after post process) : " << d_solution << std::endl;
+    Node prev = sol;
+    sol = d_qe->getTermDatabaseSygus()->getExtRewriter()->extendedRewrite(sol);
+    if (prev != sol)
+    {
+      Trace("csi-sol") << "Solution (after post process) : " << sol
+                       << std::endl;
     }
   }
 
-  // debug solution
-  if (!d_sol->debugSolution(d_solution))
+  if (reconstructed == -1)
   {
-    // This can happen if we encountered free variables in either the
-    // instantiation terms, or in the instantiation lemmas after postprocessing.
-    // In this case, we fail, since the solution is not valid.
-    Trace("csi-sol") << "FAIL : solution " << d_solution
-                     << " contains free constants." << std::endl;
-    Warning() <<
-        "Cannot get synth function: free constants encountered in synthesis "
-        "solution.";
-    reconstructed = -1;
-  }
-  if( Trace.isOn("cegqi-stats") ){
-    int tsize, itesize;
-    tsize = 0;itesize = 0;
-    d_sol->debugTermSize( d_orig_solution, tsize, itesize );
-    Trace("cegqi-stats") << tsize << " " << itesize << " ";
-    tsize = 0;itesize = 0;
-    d_sol->debugTermSize( d_solution, tsize, itesize );
-    Trace("cegqi-stats") << tsize << " " << itesize << " ";
-    if( !d_sygus_solution.isNull() ){
-      tsize = 0;itesize = 0;
-      d_sol->debugTermSize( d_sygus_solution, tsize, itesize );
-      Trace("cegqi-stats") << tsize << " - ";
-    }else{
-      Trace("cegqi-stats") << "null ";
-    }
-    Trace("cegqi-stats") << std::endl;
-  }
-  Node sol;
-  if( reconstructed==1 ){
-    sol = d_sygus_solution;
-  }else if( reconstructed==-1 ){
     return Node::null();
-  }else{
-    sol = d_solution;
-  }
-  //make into lambda
-  if( !dt.getSygusVarList().isNull() ){
-    Node varList = dt.getSygusVarList();
-    return NodeManager::currentNM()->mkNode( LAMBDA, varList, sol );
-  }else{
-    return sol;
   }
+  return sol;
 }
 
 void CegSingleInv::preregisterConjecture(Node q) { d_orig_conjecture = q; }
@@ -503,11 +539,11 @@ bool CegSingleInv::solveTrivial(Node q)
     }
     d_inst.push_back(inst);
     d_instConds.push_back(NodeManager::currentNM()->mkConst(true));
-    d_isSolved = true;
     return true;
   }
   Trace("sygus-si-trivial-solve")
       << q << " is not trivially solvable." << std::endl;
+
   return false;
 }
 
index 0e1ddba1fa6b559b2342a015aab96794b1ac219d..6d0abefb1369f435fc9e8d9c17cf37fbe5a6c9ea 100644 (file)
@@ -18,6 +18,7 @@
 #define CVC4__THEORY__QUANTIFIERS__CE_GUIDED_SINGLE_INV_H
 
 #include "context/cdlist.h"
+#include "expr/subs.h"
 #include "theory/quantifiers/cegqi/inst_strategy_cegqi.h"
 #include "theory/quantifiers/inst_match_trie.h"
 #include "theory/quantifiers/single_inv_partition.h"
@@ -40,7 +41,6 @@ class SynthConjecture;
 class CegSingleInv
 {
  private:
-  friend class CegqiOutputSingleInv;
   //presolve
   void collectPresolveEqTerms( Node n,
                                std::map< Node, std::vector< Node > >& teq );
@@ -58,18 +58,10 @@ class CegSingleInv
 
   // list of skolems for each argument of programs
   std::vector<Node> d_single_inv_arg_sk;
-  // list of variables/skolems for each program
-  std::vector<Node> d_single_inv_var;
-  std::vector<Node> d_single_inv_sk;
-  std::map<Node, int> d_single_inv_sk_index;
   // program to solution index
   std::map<Node, unsigned> d_prog_to_sol_index;
   // original conjecture
   Node d_orig_conjecture;
-  // solution
-  Node d_orig_solution;
-  Node d_solution;
-  Node d_sygus_solution;
 
  public:
   //---------------------------------representation of the solution
@@ -83,13 +75,15 @@ class CegSingleInv
    * first order conjecture for the term vectors above.
    */
   std::vector<Node> d_instConds;
+  /** The solutions, without reconstruction to syntax */
+  std::vector<Node> d_solutions;
+  /** The solutions, after reconstruction to syntax */
+  std::vector<Node> d_rcSolutions;
   /** is solved */
   bool d_isSolved;
   //---------------------------------end representation of the solution
 
  private:
-  // conjecture
-  Node d_quant;
   Node d_simp_quant;
   // are we single invocation?
   bool d_single_invocation;
@@ -102,8 +96,7 @@ class CegSingleInv
 
   // get simplified conjecture
   Node getSimplifiedConjecture() { return d_simp_quant; }
- public:
-  // initialize this class for synthesis conjecture q
+  /** initialize this class for synthesis conjecture q */
   void initialize( Node q );
   /** finish initialize
    *
@@ -122,8 +115,25 @@ class CegSingleInv
    * found a solution to the synthesis conjecture using this method.
    */
   bool solve();
-  //get solution
-  Node getSolution( unsigned sol_index, TypeNode stn, int& reconstructed, bool rconsSygus = true );
+  /**
+   * Get solution for the sol_index^th function to synthesize of the conjecture
+   * this class was assigned.
+   *
+   * @param sol_index The index of the function to synthesize
+   * @param stn The sygus type of the solution, which corresponds to syntactic
+   * restrictions
+   * @param reconstructed Set to the status of reconstructing the solution,
+   * where 1 = success, 0 = no reconstruction specified, -1 = failed
+   * @param rconsSygus Whether to apply sygus reconstruction techniques based
+   * on the underlying reconstruction module. If this is false, then the
+   * solution does not necessarily fit the grammar.
+   * @return the solution for the sol_index^th function to synthesize of the
+   * conjecture assigned to this class.
+   */
+  Node getSolution(size_t sol_index,
+                   TypeNode stn,
+                   int& reconstructed,
+                   bool rconsSygus = true);
   //reconstruct to syntax
   Node reconstructToSyntax( Node s, TypeNode stn, int& reconstructed,
                             bool rconsSygus = true );
@@ -140,6 +150,27 @@ class CegSingleInv
    * unsatisfiable for instantiation {x1 -> t1 ... xn -> tn}.
    */
   bool solveTrivial(Node q);
+  /**
+   * Get solution from the instantiations stored in this class (d_inst) for
+   * the index^th function to synthesize. The vector d_inst should be
+   * initialized before calling this method.
+   */
+  Node getSolutionFromInst(size_t index);
+  /**
+   * Set solution, which sets the d_solutions / d_rcSolutions fields based on
+   * calls to the above method getSolutionFromInst.
+   */
+  void setSolution();
+  /** The conjecture */
+  Node d_quant;
+  //-------------- decomposed conjecture
+  /** All functions */
+  std::vector<Node> d_funs;
+  /** Unsolved functions */
+  std::vector<Node> d_unsolvedf;
+  /** Mapping of solved functions */
+  Subs d_solvedf;
+  //-------------- end decomposed conjecture
 };
 
 }/* namespace CVC4::theory::quantifiers */
index be62d2e09a85063e71f1d91edc4b8e5b6bc812a0..1748dea8e814ef88e972d2c289f1ff350661ab29 100644 (file)
@@ -49,42 +49,6 @@ CegSingleInvSol::CegSingleInvSol(QuantifiersEngine* qe)
 {
 }
 
-bool CegSingleInvSol::debugSolution(Node sol)
-{
-  if( sol.getKind()==SKOLEM ){
-    return false;
-  }else{
-    for( unsigned i=0; i<sol.getNumChildren(); i++ ){
-      if( !debugSolution( sol[i] ) ){
-        return false;
-      }
-    }
-    return true;
-  }
-
-}
-
-void CegSingleInvSol::debugTermSize(Node sol, int& t_size, int& num_ite)
-{
-  std::map< Node, int >::iterator it = d_dterm_size.find( sol );
-  if( it==d_dterm_size.end() ){
-    int prev = t_size;
-    int prev_ite = num_ite;
-    t_size++;
-    if( sol.getKind()==ITE ){
-      num_ite++;
-    }
-    for( unsigned i=0; i<sol.getNumChildren(); i++ ){
-      debugTermSize( sol[i], t_size, num_ite );
-    }
-    d_dterm_size[sol] = t_size-prev;
-    d_dterm_ite_size[sol] = num_ite-prev_ite;
-  }else{
-    t_size += it->second;
-    num_ite += d_dterm_ite_size[sol];
-  }
-}
-
 void CegSingleInvSol::preregisterConjecture(Node q)
 {
   Trace("csi-sol") << "Preregister conjecture : " << q << std::endl;
index 6a2b23503330c706a8d35e10fec149378586dd03..418f8c00ab453d79759a4633e76631a8abd00b5f 100644 (file)
@@ -48,10 +48,6 @@ class CegSingleInvSol
   std::vector< Node > d_varList;
   std::map< Node, int > d_dterm_size;
   std::map< Node, int > d_dterm_ite_size;
-//solution simplification
-private:
-  bool debugSolution( Node sol );
-  void debugTermSize( Node sol, int& t_size, int& num_ite );
 
  public:
   CegSingleInvSol(QuantifiersEngine* qe);