Add additional ground terms to SyGuS instantiation grammar (#5167)
authorMathias Preiner <mathias.preiner@gmail.com>
Thu, 1 Oct 2020 20:48:37 +0000 (13:48 -0700)
committerGitHub <noreply@github.com>
Thu, 1 Oct 2020 20:48:37 +0000 (13:48 -0700)
This PR adds options to add additional ground terms to the SyGuS instantiation grammars.

src/options/quantifiers_options.toml
src/theory/quantifiers/sygus_inst.cpp
src/theory/quantifiers/sygus_inst.h
src/theory/quantifiers_engine.cpp

index 4b130158ca8fb63d8c40ce4c1df0db9d988919d4..724a2ef2bb6b185a077690f160cc37c5c1da49b1 100644 (file)
@@ -1995,3 +1995,39 @@ header = "options/quantifiers_options.h"
   type       = "bool"
   default    = "false"
   help       = "Enable SyGuS instantiation quantifiers module"
+
+[[option]]
+  name       = "sygusInstScope"
+  category   = "regular"
+  long       = "sygus-inst-scope=MODE"
+  type       = "SygusInstScope"
+  default    = "IN"
+  help       = "select scope of ground terms"
+  help_mode  = "scope for collecting ground terms for the grammar."
+[[option.mode.IN]]
+  name = "in"
+  help = "use ground terms inside given quantified formula only."
+[[option.mode.OUT]]
+  name = "out"
+  help = "use ground terms outside of quantified formulas only."
+[[option.mode.BOTH]]
+  name = "both"
+  help = "combines inside and outside."
+
+[[option]]
+  name       = "sygusInstTermSel"
+  category   = "regular"
+  long       = "sygus-inst-term-sel=MODE"
+  type       = "SygusInstTermSelMode"
+  default    = "MIN"
+  help       = "granularity for ground terms"
+  help_mode  = "Ground term selection modes."
+[[option.mode.MIN]]
+  name = "min"
+  help = "collect minimal ground terms only."
+[[option.mode.MAX]]
+  name = "max"
+  help = "collect maximal ground terms only."
+[[option.mode.BOTH]]
+  name = "both"
+  help = "combines minimal and maximal ."
index f9a6456e1de964898b5df0eed7590e821a7144f3..4192ca746c5341eaabb4566aa752c4160a932ee5 100644 (file)
@@ -29,10 +29,142 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
+namespace {
+
+/**
+ * Collect maximal ground terms with type tn in node n.
+ *
+ * @param n: Node to traverse.
+ * @param tn: Collects only terms with type tn.
+ * @param terms: Collected terms.
+ * @param cache: Caches visited nodes.
+ * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
+ */
+void getMaxGroundTerms(TNode n,
+                       TypeNode tn,
+                       std::unordered_set<Node, NodeHashFunction>& terms,
+                       std::unordered_set<TNode, TNodeHashFunction>& cache,
+                       bool skip_quant = false)
+{
+  if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
+      && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
+  {
+    return;
+  }
+
+  Trace("sygus-inst-term") << "Find maximal terms with type " << tn
+                           << " in: " << n << std::endl;
+
+  Node cur;
+  std::vector<TNode> visit;
+
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+
+    if (cache.find(cur) != cache.end())
+    {
+      continue;
+    }
+    cache.insert(cur);
+
+    if (expr::hasBoundVar(cur) || cur.getType() != tn)
+    {
+      if (!skip_quant || cur.getKind() != kind::FORALL)
+      {
+        visit.insert(visit.end(), cur.begin(), cur.end());
+      }
+    }
+    else
+    {
+      terms.insert(cur);
+      Trace("sygus-inst-term") << "  found: " << cur << std::endl;
+    }
+  } while (!visit.empty());
+}
+
+/*
+ * Collect minimal ground terms with type tn in node n.
+ *
+ * @param n: Node to traverse.
+ * @param tn: Collects only terms with type tn.
+ * @param terms: Collected terms.
+ * @param cache: Caches visited nodes and flags indicating whether a minimal
+ *               term was already found in a subterm.
+ * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
+ */
+void getMinGroundTerms(
+    TNode n,
+    TypeNode tn,
+    std::unordered_set<Node, NodeHashFunction>& terms,
+    std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache,
+    bool skip_quant = false)
+{
+  if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
+      && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
+  {
+    return;
+  }
+
+  Trace("sygus-inst-term") << "Find minimal terms with type " << tn
+                           << " in: " << n << std::endl;
+
+  Node cur;
+  std::vector<TNode> visit;
+
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+
+    auto it = cache.find(cur);
+    if (it == cache.end())
+    {
+      cache.emplace(cur, std::make_pair(false, false));
+      if (!skip_quant || cur.getKind() != kind::FORALL)
+      {
+        visit.push_back(cur);
+        visit.insert(visit.end(), cur.begin(), cur.end());
+      }
+    }
+    /* up-traversal */
+    else if (!it->second.first)
+    {
+      bool found_min_term = false;
+
+      /* Check if we found a minimal term in one of the children. */
+      for (const Node& c : cur)
+      {
+        found_min_term |= cache[c].second;
+        if (found_min_term) break;
+      }
+
+      /* If we haven't found a minimal term yet, add this term if it has the
+       * right type. */
+      if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term)
+      {
+        terms.insert(cur);
+        found_min_term = true;
+        Trace("sygus-inst-term") << "  found: " << cur << std::endl;
+      }
+
+      it->second.first = true;
+      it->second.second = found_min_term;
+    }
+  } while (!visit.empty());
+}
+
+}  // namespace
+
 SygusInst::SygusInst(QuantifiersEngine* qe)
     : QuantifiersModule(qe),
       d_lemma_cache(qe->getUserContext()),
-      d_ce_lemma_added(qe->getUserContext())
+      d_ce_lemma_added(qe->getUserContext()),
+      d_global_terms(qe->getUserContext()),
+      d_notified_assertions(qe->getUserContext())
 {
 }
 
@@ -149,14 +281,79 @@ void SygusInst::registerQuantifier(Node q)
   std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
   std::unordered_set<Node, NodeHashFunction> term_irrelevant;
 
-  /* Collect extra symbols in 'q' to be used in the grammar. */
-  std::unordered_set<Node, NodeHashFunction> syms;
-  expr::getSymbols(q, syms);
-  for (const TNode& var : syms)
+  /* Collect relevant local ground terms for each variable type. */
+  if (options::sygusInstScope() == options::SygusInstScope::IN
+      || options::sygusInstScope() == options::SygusInstScope::BOTH)
+  {
+    std::unordered_map<TypeNode,
+                       std::unordered_set<Node, NodeHashFunction>,
+                       TypeNodeHashFunction>
+        relevant_terms;
+    for (const Node& var : q[0])
+    {
+      TypeNode tn = var.getType();
+
+      /* Collect relevant ground terms for type tn. */
+      if (relevant_terms.find(tn) == relevant_terms.end())
+      {
+        std::unordered_set<Node, NodeHashFunction> terms;
+        std::unordered_set<TNode, TNodeHashFunction> cache_max;
+        std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
+            cache_min;
+
+        getMinGroundTerms(q, tn, terms, cache_min);
+        getMaxGroundTerms(q, tn, terms, cache_max);
+        relevant_terms.emplace(tn, terms);
+      }
+
+      /* Add relevant ground terms to grammar. */
+      auto& terms = relevant_terms[tn];
+      for (const auto& t : terms)
+      {
+        TypeNode ttn = t.getType();
+        extra_cons[ttn].insert(t);
+        Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl;
+      }
+    }
+  }
+
+  /* Collect relevant global ground terms for each variable type. */
+  if (options::sygusInstScope() == options::SygusInstScope::OUT
+      || options::sygusInstScope() == options::SygusInstScope::BOTH)
   {
-    TypeNode tn = var.getType();
-    extra_cons[tn].insert(var);
-    Trace("sygus-inst") << "Found symbol: " << var << std::endl;
+    for (const Node& var : q[0])
+    {
+      TypeNode tn = var.getType();
+
+      /* Collect relevant ground terms for type tn. */
+      if (d_global_terms.find(tn) == d_global_terms.end())
+      {
+        std::unordered_set<Node, NodeHashFunction> terms;
+        std::unordered_set<TNode, TNodeHashFunction> cache_max;
+        std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
+            cache_min;
+
+        for (const Node& a : d_notified_assertions)
+        {
+          getMinGroundTerms(a, tn, terms, cache_min, true);
+          getMaxGroundTerms(a, tn, terms, cache_max, true);
+        }
+        d_global_terms.insert(tn, terms);
+      }
+
+      /* Add relevant ground terms to grammar. */
+      auto it = d_global_terms.find(tn);
+      if (it != d_global_terms.end())
+      {
+        for (const auto& t : (*it).second)
+        {
+          TypeNode ttn = t.getType();
+          extra_cons[ttn].insert(t);
+          Trace("sygus-inst")
+              << "Adding (global) extra cons: " << t << std::endl;
+        }
+      }
+    }
   }
 
   /* Construct grammar for each bound variable of 'q'. */
@@ -190,6 +387,14 @@ void SygusInst::preRegisterQuantifier(Node q)
   addCeLemma(q);
 }
 
+void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions)
+{
+  for (const Node& a : assertions)
+  {
+    d_notified_assertions.insert(a);
+  }
+}
+
 /*****************************************************************************/
 /* private methods                                                           */
 /*****************************************************************************/
index 2361c4a2b023f9a1e63f93ee8390be6744ccf748..c95c6a02658685549847e2311499aba33107da7f 100644 (file)
@@ -82,6 +82,9 @@ class SygusInst : public QuantifiersModule
   /* Called once for every quantifier 'q' per context. */
   void preRegisterQuantifier(Node q) override;
 
+  /* For collecting global terms from all available assertions. */
+  void ppNotifyAssertions(const std::vector<Node>& assertions);
+
   std::string identify() const override { return "SygusInst"; }
 
  private:
@@ -124,6 +127,15 @@ class SygusInst : public QuantifiersModule
   /* Indicates whether a counterexample lemma was added for a quantified
    * formula in the current context. */
   context::CDHashSet<Node, NodeHashFunction> d_ce_lemma_added;
+
+  /* Set of global ground terms in assertions (outside of quantifiers). */
+  context::CDHashMap<TypeNode,
+                     std::unordered_set<Node, NodeHashFunction>,
+                     TypeNodeHashFunction>
+      d_global_terms;
+
+  /* Assertions sent by ppNotifyAssertions. */
+  context::CDHashSet<Node, NodeHashFunction> d_notified_assertions;
 };
 
 }  // namespace quantifiers
index 557d444d67dee256fb2614c6e8b672c1183b4e53..cceb04d9fc8264852e86b905d63b58b43d5cf0cf 100644 (file)
@@ -370,6 +370,14 @@ void QuantifiersEngine::ppNotifyAssertions(
       sye->preregisterAssertion(a);
     }
   }
+  /* The SyGuS instantiation module needs a global view of all available
+   * assertions to collect global terms that get added to each grammar.
+   */
+  if (options::sygusInst())
+  {
+    quantifiers::SygusInst* si = d_qmodules->d_sygus_inst.get();
+    si->ppNotifyAssertions(assertions);
+  }
 }
 
 void QuantifiersEngine::check( Theory::Effort e ){
@@ -976,8 +984,11 @@ void QuantifiersEngine::flushLemmas(){
     //take default output channel if none is provided
     d_hasAddedLemma = true;
     std::map<Node, ProofGenerator*>::iterator itp;
-    for (const Node& lemw : d_lemmas_waiting)
+    // Note: Do not use foreach loop here and do not cache size() call.
+    // New lemmas can be added while iterating over d_lemmas_waiting.
+    for (size_t i = 0; i < d_lemmas_waiting.size(); ++i)
     {
+      const Node& lemw = d_lemmas_waiting[i];
       Trace("qe-lemma") << "Lemma : " << lemw << std::endl;
       itp = d_lemmasWaitingPg.find(lemw);
       if (itp != d_lemmasWaitingPg.end())