Allow partial models with optimized sygus enumeration (#2682)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 5 Nov 2018 15:25:33 +0000 (09:25 -0600)
committerGitHub <noreply@github.com>
Mon, 5 Nov 2018 15:25:33 +0000 (09:25 -0600)
src/theory/quantifiers/sygus/enum_stream_substitution.cpp
src/theory/quantifiers/sygus/enum_stream_substitution.h
src/theory/quantifiers/sygus/sygus_enumerator.cpp
src/theory/quantifiers/sygus/sygus_enumerator.h
src/theory/quantifiers/sygus/synth_conjecture.cpp
src/theory/quantifiers/sygus/synth_conjecture.h
src/theory/quantifiers/sygus/term_database_sygus.cpp
test/regress/regress1/sygus/commutative-stream.sy

index 24770ade034dc4dc0aa3e8a9650544e8e6b6ee4e..e8daa425688d2d37cbe890c08585864faab6f179 100644 (file)
@@ -608,8 +608,17 @@ bool EnumStreamSubstitution::CombinationState::getNextCombination()
 }
 
 void EnumStreamConcrete::initialize(Node e) { d_ess.initialize(e.getType()); }
-void EnumStreamConcrete::addValue(Node v) { d_ess.resetValue(v); }
-Node EnumStreamConcrete::getNext() { return d_ess.getNext(); }
+void EnumStreamConcrete::addValue(Node v)
+{
+  d_ess.resetValue(v);
+  d_currTerm = d_ess.getNext();
+}
+bool EnumStreamConcrete::increment()
+{
+  d_currTerm = d_ess.getNext();
+  return !d_currTerm.isNull();
+}
+Node EnumStreamConcrete::getCurrent() { return d_currTerm; }
 }  // namespace quantifiers
 }  // namespace theory
 }  // namespace CVC4
index 38fa0627bbc14d2d2dc2489fc79e224c1b72ce3c..476a364ea55725c4da884c2d3cfa26eed506b849 100644 (file)
@@ -286,12 +286,16 @@ class EnumStreamConcrete : public EnumValGenerator
   void initialize(Node e) override;
   /** get that value v was enumerated */
   void addValue(Node v) override;
-  /** get the next value enumerated by this class */
-  Node getNext() override;
+  /** increment */
+  bool increment() override;
+  /** get the current value enumerated by this class */
+  Node getCurrent() override;
 
  private:
   /** stream substitution utility */
   EnumStreamSubstitution d_ess;
+  /** the current term generated by this class */
+  Node d_currTerm;
 };
 
 }  // namespace quantifiers
index a39c9e958341a3e36e5e491d4723f13b4b35f015..aab580650b93ff07eccb6b193a20383c8c29b83f 100644 (file)
@@ -25,11 +25,7 @@ namespace theory {
 namespace quantifiers {
 
 SygusEnumerator::SygusEnumerator(TermDbSygus* tds, SynthConjecture* p)
-    : d_tds(tds),
-      d_parent(p),
-      d_tlEnum(nullptr),
-      d_abortSize(-1),
-      d_firstTime(false)
+    : d_tds(tds), d_parent(p), d_tlEnum(nullptr), d_abortSize(-1)
 {
 }
 
@@ -39,7 +35,6 @@ void SygusEnumerator::initialize(Node e)
   d_etype = d_enum.getType();
   d_tlEnum = getMasterEnumForType(d_etype);
   d_abortSize = options::sygusAbortSize();
-  d_firstTime = true;
 }
 
 void SygusEnumerator::addValue(Node v)
@@ -47,17 +42,9 @@ void SygusEnumerator::addValue(Node v)
   // do nothing
 }
 
-Node SygusEnumerator::getNext()
+bool SygusEnumerator::increment() { return d_tlEnum->increment(); }
+Node SygusEnumerator::getCurrent()
 {
-  if (d_firstTime)
-  {
-    d_firstTime = false;
-  }
-  else if (!d_tlEnum->increment())
-  {
-    // no more values
-    return Node::null();
-  }
   if (d_abortSize >= 0)
   {
     int cs = static_cast<int>(d_tlEnum->getCurrentSize());
@@ -70,8 +57,12 @@ Node SygusEnumerator::getNext()
     }
   }
   Node ret = d_tlEnum->getCurrent();
-  Trace("sygus-enum") << "Enumerate : " << d_tds->sygusToBuiltin(ret)
-                      << std::endl;
+  if (Trace.isOn("sygus-enum"))
+  {
+    Trace("sygus-enum") << "Enumerate : ";
+    TermDbSygus::toStreamSygus("sygus-enum", ret);
+    Trace("sygus-enum") << std::endl;
+  }
   return ret;
 }
 
@@ -241,6 +232,7 @@ bool SygusEnumerator::TermCache::addTerm(Node n)
     d_terms.push_back(n);
     return true;
   }
+  Assert(!n.isNull());
   if (options::sygusSymBreakDynamic())
   {
     Node bn = d_tds->sygusToBuiltin(n);
@@ -387,7 +379,7 @@ bool SygusEnumerator::TermEnumSlave::validateIndex()
   Trace("sygus-enum-debug2") << "slave(" << d_tn << ") : validate index...\n";
   SygusEnumerator::TermCache& tc = d_se->d_tcache[d_tn];
   // ensure that index is in the range
-  if (d_index >= tc.getNumTerms())
+  while (d_index >= tc.getNumTerms())
   {
     Assert(d_index == tc.getNumTerms());
     Trace("sygus-enum-debug2") << "slave(" << d_tn << ") : force master...\n";
@@ -497,6 +489,7 @@ SygusEnumerator::TermEnum* SygusEnumerator::getMasterEnumForType(TypeNode tn)
 SygusEnumerator::TermEnumMaster::TermEnumMaster()
     : TermEnum(),
       d_isIncrementing(false),
+      d_currTermSet(false),
       d_consClassNum(0),
       d_ccWeight(0),
       d_consNum(0),
@@ -518,6 +511,7 @@ bool SygusEnumerator::TermEnumMaster::initialize(SygusEnumerator* se,
   d_currChildSize = 0;
   d_ccCons.clear();
   d_isIncrementing = false;
+  d_currTermSet = false;
   bool ret = increment();
   Trace("sygus-enum-debug") << "master(" << tn
                             << "): finish init, ret = " << ret << "\n";
@@ -526,10 +520,11 @@ bool SygusEnumerator::TermEnumMaster::initialize(SygusEnumerator* se,
 
 Node SygusEnumerator::TermEnumMaster::getCurrent()
 {
-  if (!d_currTerm.isNull())
+  if (d_currTermSet)
   {
     return d_currTerm;
   }
+  d_currTermSet = true;
   // construct based on the children
   std::vector<Node> children;
   const Datatype& dt = d_tn.getDatatype();
@@ -541,7 +536,13 @@ Node SygusEnumerator::TermEnumMaster::getCurrent()
   for (unsigned i = 0, nargs = dt[cnum].getNumArgs(); i < nargs; i++)
   {
     Assert(d_children.find(i) != d_children.end());
-    children.push_back(d_children[i].getCurrent());
+    Node cc = d_children[i].getCurrent();
+    if (cc.isNull())
+    {
+      d_currTerm = cc;
+      return cc;
+    }
+    children.push_back(cc);
   }
   d_currTerm = NodeManager::currentNM()->mkNode(APPLY_CONSTRUCTOR, children);
   return d_currTerm;
@@ -708,21 +709,27 @@ bool SygusEnumerator::TermEnumMaster::incrementInternal()
     Assert(d_childrenValid == d_ccTypes.size());
 
     // do we have more constructors for the given children?
-    while (d_consNum < d_ccCons.size())
+    if (d_consNum < d_ccCons.size())
     {
       Trace("sygus-enum-debug2") << "master(" << d_tn << "): try constructor "
                                  << d_consNum << std::endl;
       // increment constructor index
       // we will build for the current constructor and the given children
       d_consNum++;
+      d_currTermSet = false;
       d_currTerm = Node::null();
       Node c = getCurrent();
-      if (tc.addTerm(c))
+      if (!c.isNull())
       {
-        return true;
+        if (!tc.addTerm(c))
+        {
+          // the term was not unique based on rewriting
+          Trace("sygus-enum-debug2") << "master(" << d_tn
+                                     << "): failed addTerm\n";
+          d_currTerm = Node::null();
+        }
       }
-      Trace("sygus-enum-debug2") << "master(" << d_tn << "): failed addTerm\n";
-      // the term was not unique based on rewriting
+      return true;
     }
 
     // finished constructors for this set of children, must increment children
index 10a44da03437740973be90b40c86f1ddbaa1f2c2..28f8f4126194baf97f5b986ebdec24958872bc00 100644 (file)
@@ -50,8 +50,10 @@ class SygusEnumerator : public EnumValGenerator
   void initialize(Node e) override;
   /** Inform this generator that abstract value v was enumerated. */
   void addValue(Node v) override;
+  /** increment */
+  bool increment() override;
   /** Get the next concrete value generated by this class. */
-  Node getNext() override;
+  Node getCurrent() override;
 
  private:
   /** pointer to term database sygus */
@@ -322,6 +324,8 @@ class SygusEnumerator : public EnumValGenerator
     bool d_isIncrementing;
     /** cache for getCurrent() */
     Node d_currTerm;
+    /** is d_currTerm set */
+    bool d_currTermSet;
     //----------------------------- current constructor class information
     /** the next constructor class we are using */
     unsigned d_consClassNum;
@@ -429,8 +433,6 @@ class SygusEnumerator : public EnumValGenerator
   TermEnum* d_tlEnum;
   /** the abort size, caches the value of --sygus-abort-size */
   int d_abortSize;
-  /** this flag is true for the first time to getNext() after initialize(e) */
-  bool d_firstTime;
   /** get master enumerator for type tn */
   TermEnum* getMasterEnumForType(TypeNode tn);
 };
index e668b22061f7964a30b7e42b946353d68647ea79..03344d2e78c8200f096f23056258187a830c7ef2 100644 (file)
@@ -346,14 +346,17 @@ bool SynthConjecture::doCheck(std::vector<Node>& lems)
   {
     // get the model value of the relevant terms from the master module
     std::vector<Node> enum_values;
-    bool fullModel = getEnumeratedValues(terms, enum_values);
+    bool activeIncomplete = false;
+    bool fullModel = getEnumeratedValues(terms, enum_values, activeIncomplete);
 
     // if the master requires a full model and the model is partial, we fail
     if (!d_master->allowPartialModel() && !fullModel)
     {
       // we retain the values in d_ev_active_gen_waiting
       Trace("cegqi-engine") << "...partial model, fail." << std::endl;
-      return true;
+      // if we are partial due to an active enumerator, we may still succeed
+      // on the next call
+      return !activeIncomplete;
     }
     // the waiting values are passed to the module below, clear
     d_ev_active_gen_waiting.clear();
@@ -396,7 +399,7 @@ bool SynthConjecture::doCheck(std::vector<Node>& lems)
     if (emptyModel)
     {
       Trace("cegqi-engine") << "...empty model, fail." << std::endl;
-      return true;
+      return !activeIncomplete;
     }
     Assert(candidate_values.empty());
     constructed_cand = d_master->constructCandidates(
@@ -650,7 +653,8 @@ void SynthConjecture::preregisterConjecture(Node q)
 }
 
 bool SynthConjecture::getEnumeratedValues(std::vector<Node>& n,
-                                          std::vector<Node>& v)
+                                          std::vector<Node>& v,
+                                          bool& activeIncomplete)
 {
   std::vector<Node> ncheck = n;
   n.clear();
@@ -670,7 +674,7 @@ bool SynthConjecture::getEnumeratedValues(std::vector<Node>& n,
         continue;
       }
     }
-    Node nv = getEnumeratedValue(e);
+    Node nv = getEnumeratedValue(e, activeIncomplete);
     n.push_back(e);
     v.push_back(nv);
     ret = ret && !nv.isNull();
@@ -692,42 +696,50 @@ class EnumValGeneratorBasic : public EnumValGenerator
   /** initialize (do nothing) */
   void initialize(Node e) override {}
   /** initialize (do nothing) */
-  void addValue(Node v) override {}
+  void addValue(Node v) override { d_currTerm = *d_te; }
   /**
    * Get next returns the next (T-rewriter-unique) value based on the type
    * enumerator.
    */
-  Node getNext() override
+  bool increment() override
   {
     if (d_te.isFinished())
     {
-      return Node::null();
+      d_currTerm = Node::null();
+      return false;
     }
-    Node next = *d_te;
+    d_currTerm = *d_te;
     ++d_te;
     if (options::sygusSymBreakDynamic())
     {
-      Node nextb = d_tds->sygusToBuiltin(next);
+      Node nextb = d_tds->sygusToBuiltin(d_currTerm);
       nextb = d_tds->getExtRewriter()->extendedRewrite(nextb);
-      if (d_cache.find(nextb) != d_cache.end())
+      if (d_cache.find(nextb) == d_cache.end())
+      {
+        d_cache.insert(nextb);
+        // only return the current term if not unique
+      }
+      else
       {
-        return getNext();
+        d_currTerm = Node::null();
       }
-      d_cache.insert(nextb);
     }
-    return next;
+    return true;
   }
-
+  /** get the current term */
+  Node getCurrent() override { return d_currTerm; }
  private:
   /** pointer to term database sygus */
   TermDbSygus* d_tds;
   /** the type enumerator */
   TypeEnumerator d_te;
+  /** the current term */
+  Node d_currTerm;
   /** cache of (enumerated) builtin values we have enumerated so far */
   std::unordered_set<Node, NodeHashFunction> d_cache;
 };
 
-Node SynthConjecture::getEnumeratedValue(Node e)
+Node SynthConjecture::getEnumeratedValue(Node e, bool& activeIncomplete)
 {
   bool isEnum = d_tds->isEnumerator(e);
 
@@ -790,6 +802,7 @@ Node SynthConjecture::getEnumeratedValue(Node e)
   // Check if there is an (abstract) value absE we were actively generating
   // values based on.
   Node absE = d_ev_curr_active_gen[e];
+  bool firstTime = false;
   if (absE.isNull())
   {
     // None currently exist. The next abstract value is the model value for e.
@@ -804,9 +817,22 @@ Node SynthConjecture::getEnumeratedValue(Node e)
     }
     d_ev_curr_active_gen[e] = absE;
     iteg->second->addValue(absE);
+    firstTime = true;
+  }
+  bool inc = true;
+  if (!firstTime)
+  {
+    inc = iteg->second->increment();
+  }
+  Node v;
+  if (inc)
+  {
+    v = iteg->second->getCurrent();
   }
-  Node v = iteg->second->getNext();
-  if (v.isNull())
+  Trace("sygus-active-gen-debug") << "...generated " << v
+                                  << ", with increment success : " << inc
+                                  << std::endl;
+  if (!inc)
   {
     // No more concrete values generated from absE.
     NodeManager* nm = NodeManager::currentNM();
@@ -852,7 +878,14 @@ Node SynthConjecture::getEnumeratedValue(Node e)
   else
   {
     // We are waiting to send e -> v to the module that requested it.
-    d_ev_active_gen_waiting[e] = v;
+    if (v.isNull())
+    {
+      activeIncomplete = true;
+    }
+    else
+    {
+      d_ev_active_gen_waiting[e] = v;
+    }
     if (Trace.isOn("sygus-active-gen"))
     {
       Trace("sygus-active-gen") << "Active-gen : " << e << " : ";
index ef1b4459f743043383c05cfab8a102b23550ebca..3a43eb83df3960b8eab494e8697f0065bfbb1d95 100644 (file)
@@ -49,8 +49,14 @@ class EnumValGenerator
   virtual void initialize(Node e) = 0;
   /** Inform this generator that abstract value v was enumerated. */
   virtual void addValue(Node v) = 0;
-  /** Get the next concrete value generated by this class. */
-  virtual Node getNext() = 0;
+  /**
+   * Increment this value generator. If this returns false, then we are out of
+   * values. If this returns true, getCurrent(), if non-null, returns the
+   * current term.
+   */
+  virtual bool increment() = 0;
+  /** Get the current concrete value generated by this class. */
+  virtual Node getCurrent() = 0;
 };
 
 /** a synthesis conjecture
@@ -193,15 +199,25 @@ class SynthConjecture
    * Get model values for terms n, store in vector v. This method returns true
    * if and only if all values added to v are non-null.
    *
+   * The argument activeIncomplete indicates whether n contains an active
+   * enumerator that is currently not finished enumerating values, but returned
+   * null on a call to getEnumeratedValue. This value is used for determining
+   * whether we should call getEnumeratedValues again within a call to
+   * SynthConjecture::check.
+   *
    * It removes terms from n that correspond to "inactive" enumerators, that
    * is, enumerators whose values have been exhausted.
    */
-  bool getEnumeratedValues(std::vector<Node>& n, std::vector<Node>& v);
+  bool getEnumeratedValues(std::vector<Node>& n,
+                           std::vector<Node>& v,
+                           bool& activeIncomplete);
   /**
    * Get model value for term n. If n has a value that was excluded by
-   * datatypes sygus symmetry breaking, this method returns null.
+   * datatypes sygus symmetry breaking, this method returns null. It sets
+   * activeIncomplete to true if there is an actively-generated enumerator whose
+   * current value is null but it has not finished generating values.
    */
-  Node getEnumeratedValue(Node n);
+  Node getEnumeratedValue(Node n, bool& activeIncomplete);
   /** enumerator generators for each actively-generated enumerator */
   std::map<Node, std::unique_ptr<EnumValGenerator> > d_evg;
   /**
index 435e1a00f2404491c05726fe24f4f80ad0f62d1e..9af9900865f1f12bf66f0bbc00c28ff87427db11 100644 (file)
@@ -756,9 +756,16 @@ void TermDbSygus::toStreamSygus(const char* c, Node n)
 {
   if (Trace.isOn(c))
   {
-    std::stringstream ss;
-    Printer::getPrinter(options::outputLanguage())->toStreamSygus(ss, n);
-    Trace(c) << ss.str();
+    if (n.isNull())
+    {
+      Trace(c) << n;
+    }
+    else
+    {
+      std::stringstream ss;
+      Printer::getPrinter(options::outputLanguage())->toStreamSygus(ss, n);
+      Trace(c) << ss.str();
+    }
   }
 }
 
index 7b96a2bf32bd997c6d5078b95ecea4dea39889a9..8203fd9cf8b2493eff29bfdeab63c7a48ee2d75b 100644 (file)
@@ -3,7 +3,7 @@
 ; EXPECT: (error "Maximum term size (2) for enumerative SyGuS exceeded.")
 ; EXIT: 1
 
-; COMMAND-LINE: --sygus-stream --sygus-abort-size=2 --decision=justification
+; COMMAND-LINE: --sygus-stream --sygus-abort-size=2 --sygus-active-gen=none --decision=justification
 
 (set-logic LIA)