Incorporate static PBE symmetry breaking lemmas into SygusEnumerator (#2690)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 6 Nov 2018 23:28:41 +0000 (17:28 -0600)
committerGitHub <noreply@github.com>
Tue, 6 Nov 2018 23:28:41 +0000 (17:28 -0600)
src/theory/datatypes/datatypes_sygus.cpp
src/theory/quantifiers/sygus/sygus_enumerator.cpp
src/theory/quantifiers/sygus/sygus_enumerator.h
src/theory/quantifiers/sygus/sygus_pbe.cpp
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h

index a7763bff144734d4b4a6ad0a22b67cd7bae618e8..4c7dcec9a6206b21848e7f9fc33a20c0641e8ff4 100644 (file)
@@ -1304,7 +1304,8 @@ void SygusSymBreakNew::preRegisterTerm( TNode n, std::vector< Node >& lemmas  )
   }
 }
 
-void SygusSymBreakNew::registerSizeTerm( Node e, std::vector< Node >& lemmas ) {
+void SygusSymBreakNew::registerSizeTerm(Node e, std::vector<Node>& lemmas)
+{
   if (d_register_st.find(e) != d_register_st.end())
   {
     // already registered
@@ -1545,15 +1546,32 @@ void SygusSymBreakNew::check( std::vector< Node >& lemmas ) {
       {
         // symmetry breaking lemmas should only be for enumerators
         Assert(d_register_st[a]);
-        std::vector<Node> sbl;
-        d_tds->getSymBreakLemmas(a, sbl);
-        for (const Node& lem : sbl)
+        // If this is a non-basic enumerator, process its symmetry breaking
+        // clauses. Since this class is not responsible for basic enumerators,
+        // their symmetry breaking clauses are ignored.
+        if (!d_tds->isBasicEnumerator(a))
         {
-          TypeNode tn = d_tds->getTypeForSymBreakLemma(lem);
-          unsigned sz = d_tds->getSizeForSymBreakLemma(lem);
-          registerSymBreakLemma(tn, lem, sz, a, lemmas);
+          std::vector<Node> sbl;
+          d_tds->getSymBreakLemmas(a, sbl);
+          for (const Node& lem : sbl)
+          {
+            if (d_tds->isSymBreakLemmaTemplate(lem))
+            {
+              // register the lemma template
+              TypeNode tn = d_tds->getTypeForSymBreakLemma(lem);
+              unsigned sz = d_tds->getSizeForSymBreakLemma(lem);
+              registerSymBreakLemma(tn, lem, sz, a, lemmas);
+            }
+            else
+            {
+              Trace("dt-sygus-debug")
+                  << "DT sym break lemma : " << lem << std::endl;
+              // it is a normal lemma
+              lemmas.push_back(lem);
+            }
+          }
+          d_tds->clearSymBreakLemmas(a);
         }
-        d_tds->clearSymBreakLemmas(a);
       }
     }
     if (!lemmas.empty())
@@ -1563,9 +1581,20 @@ void SygusSymBreakNew::check( std::vector< Node >& lemmas ) {
   }
 
   // register search values, add symmetry breaking lemmas if applicable
-  for( std::map< Node, bool >::iterator it = d_register_st.begin(); it != d_register_st.end(); ++it ){
-    if( it->second ){
-      Node prog = it->first;
+  std::vector<Node> es;
+  d_tds->getEnumerators(es);
+  bool needsRecheck = false;
+  // for each enumerator registered to d_tds
+  for (Node& prog : es)
+  {
+    if (d_register_st.find(prog) == d_register_st.end())
+    {
+      // not yet registered, do so now
+      registerSizeTerm(prog, lemmas);
+      needsRecheck = true;
+    }
+    else
+    {
       Trace("dt-sygus-debug") << "Checking model value of " << prog << "..."
                               << std::endl;
       Assert(prog.getType().isDatatype());
@@ -1624,14 +1653,12 @@ void SygusSymBreakNew::check( std::vector< Node >& lemmas ) {
       prog.setAttribute(ssbo, !isExc);
     }
   }
-  //register any measured terms that we haven't encountered yet (should only be invoked on first call to check
-  Trace("sygus-sb") << "Register size terms..." << std::endl;
-  std::vector< Node > mts;
-  d_tds->getEnumerators(mts);
-  for( unsigned i=0; i<mts.size(); i++ ){
-    registerSizeTerm( mts[i], lemmas );
+  Trace("sygus-sb") << "SygusSymBreakNew::check: finished." << std::endl;
+  if (needsRecheck)
+  {
+    Trace("sygus-sb") << " SygusSymBreakNew::rechecking..." << std::endl;
+    return check(lemmas);
   }
-  Trace("sygus-sb") << " SygusSymBreakNew::check: finished." << std::endl;
 
   if (Trace.isOn("cegqi-engine") && !d_szinfo.empty())
   {
index c3dd56127b4bd251e4549ef9e3761bdd160ca8c8..5c3e44a3334764e7bd15973d8aff1a8ed0fa04c3 100644 (file)
@@ -31,10 +31,72 @@ SygusEnumerator::SygusEnumerator(TermDbSygus* tds, SynthConjecture* p)
 
 void SygusEnumerator::initialize(Node e)
 {
+  Trace("sygus-enum") << "SygusEnumerator::initialize " << e << std::endl;
   d_enum = e;
   d_etype = d_enum.getType();
+  Assert(d_etype.isDatatype());
+  Assert(d_etype.getDatatype().isSygus());
   d_tlEnum = getMasterEnumForType(d_etype);
   d_abortSize = options::sygusAbortSize();
+
+  // Get the statically registered symmetry breaking clauses for e, see if they
+  // can be used for speeding up the enumeration.
+  NodeManager* nm = NodeManager::currentNM();
+  std::vector<Node> sbl;
+  d_tds->getSymBreakLemmas(e, sbl);
+  Node ag = d_tds->getActiveGuardForEnumerator(e);
+  Node truen = nm->mkConst(true);
+  // use TNode for substitute below
+  TNode agt = ag;
+  TNode truent = truen;
+  Assert(d_tcache.find(d_etype) != d_tcache.end());
+  const Datatype& dt = d_etype.getDatatype();
+  for (const Node& lem : sbl)
+  {
+    if (!d_tds->isSymBreakLemmaTemplate(lem))
+    {
+      // substitute its active guard by true and rewrite
+      Node slem = lem.substitute(agt, truent);
+      slem = Rewriter::rewrite(slem);
+      // break into conjuncts
+      std::vector<Node> sblc;
+      if (slem.getKind() == AND)
+      {
+        for (const Node& slemc : slem)
+        {
+          sblc.push_back(slemc);
+        }
+      }
+      else
+      {
+        sblc.push_back(slem);
+      }
+      for (const Node& sbl : sblc)
+      {
+        Trace("sygus-enum")
+            << "  symmetry breaking lemma : " << sbl << std::endl;
+        // if its a negation of a unit top-level tester, then this specifies
+        // that we should not enumerate terms whose top symbol is that
+        // constructor
+        if (sbl.getKind() == NOT)
+        {
+          Node a;
+          int tst = datatypes::DatatypesRewriter::isTester(sbl[0], a);
+          if (tst >= 0)
+          {
+            if (a == e)
+            {
+              Node cons = Node::fromExpr(dt[tst].getConstructor());
+              Trace("sygus-enum") << "  ...unit exclude constructor #" << tst
+                                  << ", constructor " << cons << std::endl;
+              d_sbExcTlCons.insert(cons);
+            }
+          }
+        }
+        // other symmetry breaking lemmas such as disjunctions are not used
+      }
+    }
+  }
 }
 
 void SygusEnumerator::addValue(Node v)
@@ -57,6 +119,17 @@ Node SygusEnumerator::getCurrent()
     }
   }
   Node ret = d_tlEnum->getCurrent();
+  if (!ret.isNull() && !d_sbExcTlCons.empty())
+  {
+    Assert(ret.hasOperator());
+    // might be excluded by an externally provided symmetry breaking clause
+    if (d_sbExcTlCons.find(ret.getOperator()) != d_sbExcTlCons.end())
+    {
+      Trace("sygus-enum-exc")
+          << "Exclude (external) : " << d_tds->sygusToBuiltin(ret) << std::endl;
+      ret = Node::null();
+    }
+  }
   if (Trace.isOn("sygus-enum"))
   {
     Trace("sygus-enum") << "Enumerate : ";
index 28f8f4126194baf97f5b986ebdec24958872bc00..af6bb03f01f7f4ce273620f43e552ec679bc2338 100644 (file)
@@ -435,6 +435,15 @@ class SygusEnumerator : public EnumValGenerator
   int d_abortSize;
   /** get master enumerator for type tn */
   TermEnum* getMasterEnumForType(TypeNode tn);
+  //-------------------------------- externally specified symmetry breaking
+  /** set of constructors we disallow at top level
+   *
+   * A constructor C is disallowed at the top level if a symmetry breaking
+   * lemma that entails ~is-C( d_enum ) was registered to
+   * TermDbSygus::registerSymBreakLemma.
+   */
+  std::unordered_set<Node, NodeHashFunction> d_sbExcTlCons;
+  //-------------------------------- end externally specified symmetry breaking
 };
 
 }  // namespace quantifiers
index f91dd5d30f99f1c6b2fdbf81502ae030b14267b0..ee724712117231c18abe6f1022798f743f52bae9 100644 (file)
@@ -228,13 +228,24 @@ bool SygusPbe::initialize(Node n,
             {
               lem = lem.substitute(tsp, te);
             }
-            disj.push_back(lem);
+            if (std::find(disj.begin(), disj.end(), lem) == disj.end())
+            {
+              disj.push_back(lem);
+            }
           }
         }
+        // add its active guard
+        Node ag = d_tds->getActiveGuardForEnumerator(e);
+        Assert(!ag.isNull());
+        disj.push_back(ag.negate());
         Node lem = disj.size() == 1 ? disj[0] : nm->mkNode(OR, disj);
         Trace("sygus-pbe") << "  static redundant op lemma : " << lem
                            << std::endl;
-        lemmas.push_back(lem);
+        // Register as a symmetry breaking lemma with the term database.
+        // This will either be processed via a lemma on the output channel
+        // of the sygus extension of the datatypes solver, or internally
+        // encoded as a constraint to an active enumerator.
+        d_tds->registerSymBreakLemma(e, lem, etn, 0, false);
       }
     }
     Trace("sygus-pbe") << "Initialize " << d_examples[c].size()
index 5cf230820a412c99d7c44d8f8fac8d3034326ebe..9198f7e56aea0833f0016afa426284ec1bc8ad2f 100644 (file)
@@ -682,6 +682,10 @@ void TermDbSygus::registerEnumerator(Node e,
     Node ag = nm->mkSkolem("eG", nm->booleanType());
     // must ensure it is a literal immediately here
     ag = d_quantEngine->getValuation().ensureLiteral(ag);
+    // must ensure that it is asserted as a literal before we begin solving
+    Node lem = nm->mkNode(OR, ag, ag.negate());
+    d_quantEngine->getOutputChannel().requirePhase(ag, true);
+    d_quantEngine->getOutputChannel().lemma(lem);
     d_enum_to_active_guard[e] = ag;
   }
 }
@@ -771,14 +775,13 @@ void TermDbSygus::getEnumerators(std::vector<Node>& mts)
   }
 }
 
-void TermDbSygus::registerSymBreakLemma(Node e,
-                                        Node lem,
-                                        TypeNode tn,
-                                        unsigned sz)
+void TermDbSygus::registerSymBreakLemma(
+    Node e, Node lem, TypeNode tn, unsigned sz, bool isTempl)
 {
   d_enum_to_sb_lemmas[e].push_back(lem);
   d_sb_lemma_to_type[lem] = tn;
   d_sb_lemma_to_size[lem] = sz;
+  d_sb_lemma_to_isTempl[lem] = isTempl;
 }
 
 bool TermDbSygus::hasSymBreakLemmas(std::vector<Node>& enums) const
@@ -817,6 +820,13 @@ unsigned TermDbSygus::getSizeForSymBreakLemma(Node lem) const
   return it->second;
 }
 
+bool TermDbSygus::isSymBreakLemmaTemplate(Node lem) const
+{
+  std::map<Node, bool>::const_iterator it = d_sb_lemma_to_isTempl.find(lem);
+  Assert(it != d_sb_lemma_to_isTempl.end());
+  return it->second;
+}
+
 void TermDbSygus::clearSymBreakLemmas(Node e) { d_enum_to_sb_lemmas.erase(e); }
 
 bool TermDbSygus::isRegistered(TypeNode tn) const
index 2e8604411d0b542a48421e28533670d9ba4035df..7a522ded6e6bbc43f6d655589233d279cff0e3c5 100644 (file)
@@ -170,12 +170,13 @@ class TermDbSygus {
    *
    * tn : the (sygus datatype) type that lem applies to, i.e. the
    * type of terms that lem blocks models for,
-   * sz : the minimum size of terms that the lem blocks.
-   *
-   * Notice that the symmetry breaking lemma template should be relative to x,
-   * where x is returned by the call to getFreeVar( tn, 0 ) in this class.
+   * sz : the minimum size of terms that the lem blocks,
+   * isTempl : if this flag is false, then lem is a (concrete) lemma.
+   * If this flag is true, then lem is a symmetry breaking lemma template
+   * over x, where x is returned by the call to getFreeVar( tn, 0 ).
    */
-  void registerSymBreakLemma(Node e, Node lem, TypeNode tn, unsigned sz);
+  void registerSymBreakLemma(
+      Node e, Node lem, TypeNode tn, unsigned sz, bool isTempl = true);
   /** Has symmetry breaking lemmas been added for any enumerator? */
   bool hasSymBreakLemmas(std::vector<Node>& enums) const;
   /** Get symmetry breaking lemmas
@@ -188,6 +189,8 @@ class TermDbSygus {
   TypeNode getTypeForSymBreakLemma(Node lem) const;
   /** Get the minimum size of terms symmetry breaking lemma lem applies to */
   unsigned getSizeForSymBreakLemma(Node lem) const;
+  /** Returns true if lem is a lemma template, false if lem is a lemma */
+  bool isSymBreakLemmaTemplate(Node lem) const;
   /** Clear information about symmetry breaking lemmas for enumerator e */
   void clearSymBreakLemmas(Node e);
   //------------------------------end enumerators
@@ -344,6 +347,8 @@ class TermDbSygus {
   std::map<Node, TypeNode> d_sb_lemma_to_type;
   /** mapping from symmetry breaking lemmas to size */
   std::map<Node, unsigned> d_sb_lemma_to_size;
+  /** mapping from symmetry breaking lemmas to whether they are templates */
+  std::map<Node, bool> d_sb_lemma_to_isTempl;
   /** enumerators to whether they are actively-generated */
   std::map<Node, bool> d_enum_active_gen;
   /** enumerators to whether they are variable agnostic */