Optimizations for PBE strings (#2728)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sun, 2 Dec 2018 14:49:17 +0000 (08:49 -0600)
committerGitHub <noreply@github.com>
Sun, 2 Dec 2018 14:49:17 +0000 (08:49 -0600)
src/theory/quantifiers/sygus/sygus_invariance.cpp
src/theory/quantifiers/sygus/sygus_invariance.h
src/theory/quantifiers/sygus/sygus_unif_io.cpp
src/theory/quantifiers/sygus/sygus_unif_io.h

index 24b47b2166268d78e90db4ae3809f3a6852cc3b6..5ea01ef57cc74cdba36f993c01a5332f2c0c81cb 100644 (file)
@@ -218,15 +218,22 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds,
           NodeManager::currentNM()->mkNode(kind::STRING_STRCTN, out, nbvre);
       Trace("sygus-pbe-cterm-debug") << "Check: " << cont << std::endl;
       Node contr = Rewriter::rewrite(cont);
-      if (contr == tds->d_false)
+      if (!contr.isConst())
+      {
+        if (d_isUniversal)
+        {
+          return false;
+        }
+      }
+      else if (contr.getConst<bool>() == d_isUniversal)
       {
         if (Trace.isOn("sygus-pbe-cterm"))
         {
           Trace("sygus-pbe-cterm")
               << "PBE-cterm : enumerator : do not consider ";
-          Trace("sygus-pbe-cterm") << nbv << " for any "
-                                   << tds->sygusToBuiltin(x) << " since "
-                                   << std::endl;
+          Trace("sygus-pbe-cterm")
+              << nbv << " for any " << tds->sygusToBuiltin(x) << " since "
+              << std::endl;
           Trace("sygus-pbe-cterm") << "   PBE-cterm :    for input example : ";
           for (unsigned j = 0, size = d_ex[ii].size(); j < size; j++)
           {
@@ -238,13 +245,13 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds,
           Trace("sygus-pbe-cterm")
               << "   PBE-cterm : and is not in output : " << out << std::endl;
         }
-        return true;
+        return !d_isUniversal;
       }
       Trace("sygus-pbe-cterm-debug2")
           << "...check failed, rewrites to : " << contr << std::endl;
     }
   }
-  return false;
+  return d_isUniversal;
 }
 
 } /* CVC4::theory::quantifiers namespace */
index 59761da5cdab849b886500490f6d7daa2209b676..02c249411d4d440dfd90063600bceda24243f0f7 100644 (file)
@@ -249,7 +249,7 @@ class DivByZeroSygusInvarianceTest : public SygusInvarianceTest
 class NegContainsSygusInvarianceTest : public SygusInvarianceTest
 {
  public:
-  NegContainsSygusInvarianceTest() {}
+  NegContainsSygusInvarianceTest() : d_isUniversal(false) {}
 
   /** initialize this invariance test
    *  e is the enumerator which we are reasoning about (associated with a synth
@@ -266,9 +266,19 @@ class NegContainsSygusInvarianceTest : public SygusInvarianceTest
             std::vector<std::vector<Node> >& ex,
             std::vector<Node>& exo,
             std::vector<unsigned>& ncind);
+  /** set universal
+   *
+   * This updates the semantics of this check such that *all* instead of some
+   * examples must fail the containment test.
+   */
+  void setUniversal() { d_isUniversal = true; }
 
  protected:
-  /** checks if contains( out_i, nvn[in_i] ) --> false for some I/O pair i. */
+  /**
+   * Checks if contains( out_i, nvn[in_i] ) --> false for some I/O pair i; if
+   * d_isUniversal is true, then we check if the rewrite holds for *all* I/O
+   * pairs.
+   */
   bool invariant(TermDbSygus* tds, Node nvn, Node x) override;
 
  private:
@@ -282,6 +292,8 @@ class NegContainsSygusInvarianceTest : public SygusInvarianceTest
    *    contains( out_i, nvn[in_i] ) ---> false
    */
   std::vector<unsigned> d_neg_con_indices;
+  /** requires not being in all examples */
+  bool d_isUniversal;
 };
 
 } /* CVC4::theory::quantifiers namespace */
index 89619639dff1a201b60c453eaecaff965c53e50c..a6e6b54c6f6b0ed0a9637fd59a5b6f4b5f6fd71b 100644 (file)
@@ -431,10 +431,13 @@ void SubsumeTrie::getLeavesInternal(const std::vector<Node>& vals,
 {
   if (index == vals.size())
   {
+    // by convention, if we did not test any points, then we consider the
+    // evaluation along the current path to be always false.
+    int rstatus = status == -2 ? -1 : status;
     Assert(!d_term.isNull());
-    Assert(std::find(v[status].begin(), v[status].end(), d_term)
-           == v[status].end());
-    v[status].push_back(d_term);
+    Assert(std::find(v[rstatus].begin(), v[rstatus].end(), d_term)
+           == v[rstatus].end());
+    v[rstatus].push_back(d_term);
   }
   else
   {
@@ -806,9 +809,13 @@ Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas)
               || (!d_solution.isNull()
                   && d_tds->getSygusTermSize(vcc) < d_sol_term_size)))
       {
-        Trace("sygus-pbe") << "**** SygusUnif SOLVED : " << c << " = " << vcc
-                           << std::endl;
-        Trace("sygus-pbe") << "...solved at iteration " << i << std::endl;
+        if (Trace.isOn("sygus-pbe"))
+        {
+          Trace("sygus-pbe") << "**** SygusUnif SOLVED : " << c << " = ";
+          TermDbSygus::toStreamSygus("sygus-pbe", vcc);
+          Trace("sygus-pbe") << std::endl;
+          Trace("sygus-pbe") << "...solved at iteration " << i << std::endl;
+        }
         d_solution = vcc;
         newSolution = vcc;
         d_sol_term_size = d_tds->getSygusTermSize(vcc);
@@ -867,12 +874,12 @@ bool SygusUnifIo::useStrContainsEnumeratorExclude(Node e)
         d_use_str_contains_eexc[e] = false;
         return false;
       }
+      d_use_str_contains_eexc_conditional[e] = false;
       if (eis.isConditional())
       {
         Trace("sygus-sui-enum-debug")
             << "  conditional slave : " << sn << std::endl;
-        d_use_str_contains_eexc[e] = false;
-        return false;
+        d_use_str_contains_eexc_conditional[e] = true;
       }
     }
     Trace("sygus-sui-enum-debug")
@@ -895,6 +902,9 @@ bool SygusUnifIo::getExplanationForEnumeratorExclude(
     // the output for some input/output pair. If so, then this term is never
     // useful. We generalize its explanation below.
 
+    // if the enumerator is in a conditional context, then we are stricter
+    // about when to exclude
+    bool isConditional = d_use_str_contains_eexc_conditional[e];
     if (Trace.isOn("sygus-sui-cterm-debug"))
     {
       Trace("sygus-sui-enum") << std::endl;
@@ -921,12 +931,20 @@ bool SygusUnifIo::getExplanationForEnumeratorExclude(
       else
       {
         Trace("sygus-sui-cterm-debug") << "...contained." << std::endl;
+        if (isConditional)
+        {
+          return false;
+        }
       }
     }
     if (!cmp_indices.empty())
     {
       // we check invariance with respect to a negative contains test
       NegContainsSygusInvarianceTest ncset;
+      if (isConditional)
+      {
+        ncset.setUniversal();
+      }
       ncset.init(e, d_examples, d_examples_out, cmp_indices);
       // construct the generalized explanation
       d_tds->getExplain()->getExplanationFor(e, v, exp, ncset);
@@ -992,10 +1010,12 @@ Node SygusUnifIo::constructSol(
 
   EnumCache& ecache = d_ecache[e];
 
+  bool retValMod = x.isReturnValueModified();
+
   Node ret_dt;
   if (nrole == role_equal)
   {
-    if (!x.isReturnValueModified())
+    if (!retValMod)
     {
       if (ecache.isSolved())
       {
@@ -1069,11 +1089,67 @@ Node SygusUnifIo::constructSol(
         }
       }
     }
+    // maybe we can find one in the cache
+    if (ret_dt.isNull() && !retValMod)
+    {
+      bool firstTime = true;
+      std::unordered_set<Node, NodeHashFunction> intersection;
+      std::map<size_t, std::unordered_set<Node, NodeHashFunction>>::iterator
+          pit;
+      for (size_t i = 0, nvals = x.d_vals.size(); i < nvals; i++)
+      {
+        if (x.d_vals[i].getConst<bool>())
+        {
+          pit = d_psolutions.find(i);
+          if (pit == d_psolutions.end())
+          {
+            // no cached solution
+            intersection.clear();
+            break;
+          }
+          if (firstTime)
+          {
+            intersection = pit->second;
+            firstTime = false;
+          }
+          else
+          {
+            std::vector<Node> rm;
+            for (const Node& a : intersection)
+            {
+              if (pit->second.find(a) == pit->second.end())
+              {
+                rm.push_back(a);
+              }
+            }
+            for (const Node& a : rm)
+            {
+              intersection.erase(a);
+            }
+            if (intersection.empty())
+            {
+              break;
+            }
+          }
+        }
+      }
+      if (!intersection.empty())
+      {
+        ret_dt = *intersection.begin();
+        if (Trace.isOn("sygus-sui-dt"))
+        {
+          indent("sygus-sui-dt", ind);
+          Trace("sygus-sui-dt") << "ConstructPBE: found in cache: ";
+          TermDbSygus::toStreamSygus("sygus-sui-dt", ret_dt);
+          Trace("sygus-sui-dt") << std::endl;
+        }
+      }
+    }
   }
   else if (nrole == role_string_prefix || nrole == role_string_suffix)
   {
     // check if each return value is a prefix/suffix of all open examples
-    if (!x.isReturnValueModified() || x.getCurrentRole() == nrole)
+    if (!retValMod || x.getCurrentRole() == nrole)
     {
       std::map<Node, std::vector<unsigned> > incr;
       bool isPrefix = nrole == role_string_prefix;
@@ -1264,11 +1340,12 @@ Node SygusUnifIo::constructSol(
             Assert(set_split_cond_res_index);
             Assert(split_cond_res_index < ecache_cond.d_enum_vals_res.size());
             prev = x.d_vals;
-            bool ret = x.updateContext(
-                this,
-                ecache_cond.d_enum_vals_res[split_cond_res_index],
-                sc == 1);
-            AlwaysAssert(ret);
+            x.updateContext(this,
+                            ecache_cond.d_enum_vals_res[split_cond_res_index],
+                            sc == 1);
+            // return value of above call may be false in corner cases where we
+            // must choose a non-separating condition to traverse to another
+            // strategy node
           }
 
           // recurse
@@ -1284,7 +1361,7 @@ Node SygusUnifIo::constructSol(
             std::map<Node, int> solved_cond;  // stores branch
             ecache_child.d_term_trie.getLeaves(x.d_vals, true, possible_cond);
 
-            std::map<int, std::vector<Node> >::iterator itpc =
+            std::map<int, std::vector<Node>>::iterator itpc =
                 possible_cond.find(0);
             if (itpc != possible_cond.end())
             {
@@ -1301,8 +1378,6 @@ Node SygusUnifIo::constructSol(
                       << d_tds->sygusToBuiltin(cond) << std::endl;
                 }
               }
-
-              // otherwise, guess a conditional
               if (rec_c.isNull())
               {
                 rec_c = constructBestConditional(ce, itpc->second);
@@ -1381,8 +1456,35 @@ Node SygusUnifIo::constructSol(
   }
 
   Assert(ret_dt.isNull() || ret_dt.getType() == e.getType());
-  indent("sygus-sui-dt", ind);
-  Trace("sygus-sui-dt") << "ConstructPBE: returned " << ret_dt << std::endl;
+  if (Trace.isOn("sygus-sui-dt"))
+  {
+    indent("sygus-sui-dt", ind);
+    Trace("sygus-sui-dt") << "ConstructPBE: returned ";
+    TermDbSygus::toStreamSygus("sygus-sui-dt", ret_dt);
+    Trace("sygus-sui-dt") << std::endl;
+  }
+  // remember the solution
+  if (nrole == role_equal)
+  {
+    if (!retValMod && !ret_dt.isNull())
+    {
+      for (size_t i = 0, nvals = x.d_vals.size(); i < nvals; i++)
+      {
+        if (x.d_vals[i].getConst<bool>())
+        {
+          if (Trace.isOn("sygus-sui-cache"))
+          {
+            indent("sygus-sui-cache", ind);
+            Trace("sygus-sui-cache") << "Cache solution (#" << i << ") : ";
+            TermDbSygus::toStreamSygus("sygus-sui-cache", ret_dt);
+            Trace("sygus-sui-cache") << std::endl;
+          }
+          d_psolutions[i].insert(ret_dt);
+        }
+      }
+    }
+  }
+
   return ret_dt;
 }
 
index 2f87c0552b04120e86e4ebfcdc0b2a1299d08d8b..f189353b0e1b772fe0500a8b245b12118142fd6f 100644 (file)
@@ -183,12 +183,14 @@ class SubsumeTrie
                      bool pol,
                      std::vector<Node>& subsumed_by);
   /**
-  * Get the leaves of the trie, which we store in the map v.
-  * v[-1] stores the children that always evaluate to !pol,
-  * v[1] stores the children that always evaluate to pol,
-  * v[0] stores the children that both evaluate to true and false for at least
-  * one example.
-  */
+   * Get the leaves of the trie, which we store in the map v. We consider their
+   * evaluation on points such that (pol ? vals : !vals) is true.
+   *
+   * v[-1] stores the children that always evaluate to !pol,
+   * v[1] stores the children that always evaluate to pol,
+   * v[0] stores the children that both evaluate to true and false for at least
+   * one example.
+   */
   void getLeaves(const std::vector<Node>& vals,
                  bool pol,
                  std::map<int, std::vector<Node>>& v);
@@ -300,6 +302,11 @@ class SygusUnifIo : public SygusUnif
   Node d_solution;
   /** the term size of the above solution */
   unsigned d_sol_term_size;
+  /** partial solutions
+   *
+   * Maps indices for I/O points to a list of solutions for that point.
+   */
+  std::map<size_t, std::unordered_set<Node, NodeHashFunction>> d_psolutions;
   /**
    * This flag is set to true if the solution construction was
    * non-deterministic with respect to failure/success.
@@ -427,6 +434,12 @@ class SygusUnifIo : public SygusUnif
   bool useStrContainsEnumeratorExclude(Node e);
   /** cache for the above function */
   std::map<Node, bool> d_use_str_contains_eexc;
+  /**
+   * cache for the above function, stores whether enumerators e are in
+   * a conditional context, e.g. used for enumerating the return values for
+   * leaves of ITE trees.
+   */
+  std::map<Node, bool> d_use_str_contains_eexc_conditional;
 
   /** the unification context used within constructSolution */
   UnifContextIo d_context;