Improve sygus sampling for strings (#1802)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 21 Apr 2018 12:01:59 +0000 (07:01 -0500)
committerGitHub <noreply@github.com>
Sat, 21 Apr 2018 12:01:59 +0000 (07:01 -0500)
src/theory/quantifiers/sygus_sampler.cpp
src/theory/quantifiers/sygus_sampler.h

index 18e13fd59c97580bcae9346451de6d8df0454fc3..757256fc3d3edce062dec410483a97d62aaff5dd 100644 (file)
@@ -84,7 +84,9 @@ void SygusSampler::initialize(TypeNode tn,
   d_vars.clear();
   d_rvalue_cindices.clear();
   d_rvalue_null_cindices.clear();
+  d_rstring_alphabet.clear();
   d_var_sygus_types.clear();
+  d_const_sygus_types.clear();
   d_vars.insert(d_vars.end(), vars.begin(), vars.end());
   std::map<TypeNode, unsigned> type_to_type_id;
   unsigned type_id_counter = 0;
@@ -521,13 +523,63 @@ Node SygusSampler::getRandomValue(TypeNode tn)
   }
   else if (tn.isString() || tn.isInteger())
   {
+    // if string, determine the alphabet
+    if (tn.isString() && d_rstring_alphabet.empty())
+    {
+      Trace("sygus-sample-str-alpha")
+          << "Setting string alphabet..." << std::endl;
+      std::unordered_set<unsigned> alphas;
+      for (const std::pair<const Node, std::vector<TypeNode> >& c :
+           d_const_sygus_types)
+      {
+        if (c.first.getType().isString())
+        {
+          Trace("sygus-sample-str-alpha")
+              << "...have constant " << c.first << std::endl;
+          Assert(c.first.isConst());
+          std::vector<unsigned> svec = c.first.getConst<String>().getVec();
+          for (unsigned ch : svec)
+          {
+            alphas.insert(ch);
+          }
+        }
+      }
+      // can limit to 1 extra characters beyond those in the grammar (2 if
+      // there are none in the grammar)
+      unsigned num_fresh_char = alphas.empty() ? 2 : 1;
+      unsigned fresh_char = 0;
+      for (unsigned i = 0; i < num_fresh_char; i++)
+      {
+        while (alphas.find(fresh_char) != alphas.end())
+        {
+          fresh_char++;
+        }
+        alphas.insert(fresh_char);
+      }
+      Trace("sygus-sample-str-alpha")
+          << "Sygus sampler: limit strings alphabet to : " << std::endl
+          << " ";
+      for (unsigned ch : alphas)
+      {
+        d_rstring_alphabet.push_back(ch);
+        Trace("sygus-sample-str-alpha")
+            << " \"" << String::convertUnsignedIntToChar(ch) << "\"";
+      }
+      Trace("sygus-sample-str-alpha") << std::endl;
+    }
+
     std::vector<unsigned> vec;
     double ext_freq = .5;
-    unsigned base = 10;
+    unsigned base = tn.isString() ? d_rstring_alphabet.size() : 10;
     while (Random::getRandom().pickWithProb(ext_freq))
     {
       // add a digit
-      vec.push_back(Random::getRandom().pick(0, base));
+      unsigned digit = Random::getRandom().pick(0, base - 1);
+      if (tn.isString())
+      {
+        digit = d_rstring_alphabet[digit];
+      }
+      vec.push_back(digit);
     }
     if (tn.isString())
     {
@@ -680,6 +732,10 @@ void SygusSampler::registerSygusType(TypeNode tn)
         if (dtc.getNumArgs() == 0)
         {
           d_rvalue_null_cindices[tn].push_back(i);
+          if (sop.isConst())
+          {
+            d_const_sygus_types[sop].push_back(tn);
+          }
         }
       }
       // recurse on all subfields
index a66e7ee21717610a6442c0ac758bcc8e47d51e4c..b741b60d49ac12dc643ae83356ba1b2c31ec1bb3 100644 (file)
@@ -335,8 +335,12 @@ class SygusSampler : public LazyTrieEvaluator
   std::map<TypeNode, std::vector<unsigned> > d_rvalue_cindices;
   /** map from sygus types to non-variable nullary constructors */
   std::map<TypeNode, std::vector<unsigned> > d_rvalue_null_cindices;
+  /** the random string alphabet */
+  std::vector<unsigned> d_rstring_alphabet;
   /** map from variables to sygus types that include them */
   std::map<Node, std::vector<TypeNode> > d_var_sygus_types;
+  /** map from constants to sygus types that include them */
+  std::map<Node, std::vector<TypeNode> > d_const_sygus_types;
   /** register sygus type, intializes the above two data structures */
   void registerSygusType(TypeNode tn);
 };