Refactor symmetry breaking in datatypes sygus (#1640)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 6 Mar 2018 20:22:43 +0000 (14:22 -0600)
committerGitHub <noreply@github.com>
Tue, 6 Mar 2018 20:22:43 +0000 (14:22 -0600)
src/theory/datatypes/datatypes_sygus.cpp
src/theory/datatypes/datatypes_sygus.h
src/theory/quantifiers/sygus/ce_guided_conjecture.cpp
src/theory/quantifiers/sygus/sygus_explain.cpp
src/theory/quantifiers/sygus/sygus_explain.h
src/theory/quantifiers/sygus/sygus_pbe.cpp
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h

index b8185e9c8a96ab0f0069729b3008240192da99f8..1779ab27b4c089d57cb246847db7ac8fa211e0c1 100644 (file)
@@ -669,14 +669,7 @@ Node SygusSymBreakNew::getSimpleSymBreakPred( TypeNode tn, int tindex, unsigned
 }
 
 TNode SygusSymBreakNew::getFreeVar( TypeNode tn ) {
-  std::map< TypeNode, Node >::iterator it = d_free_var.find( tn );
-  if( it==d_free_var.end() ){
-    Node x = NodeManager::currentNM()->mkSkolem( "x", tn );
-    d_free_var[tn] = x;
-    return x;
-  }else{
-    return it->second;
-  }
+  return d_tds->getFreeVar(tn, 0);
 }
 
 unsigned SygusSymBreakNew::processSelectorChain( Node n, std::map< TypeNode, Node >& top_level, std::map< Node, unsigned >& tdepth, std::vector< Node >& lemmas ) {
@@ -741,15 +734,11 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d,
     Trace("sygus-sb-debug") << "  ......rewrites to " << bvr << std::endl;
     Trace("dt-sygus") << "  * DT builtin : " << n << " -> " << bvr << std::endl;
     unsigned sz = d_tds->getSygusTermSize( nv );      
-    std::vector< Node > exp;
-    bool do_exclude = false;
     if( d_tds->involvesDivByZero( bvr ) ){
-      Node x = getFreeVar( tn );
       quantifiers::DivByZeroSygusInvarianceTest dbzet;
       Trace("sygus-sb-mexp-debug") << "Minimize explanation for div-by-zero in " << d_tds->sygusToBuiltin( nv ) << std::endl;
-      d_tds->getExplain()->getExplanationFor(
-          x, nv, exp, dbzet, Node::null(), sz);
-      do_exclude = true;
+      registerSymBreakLemmaForValue(a, nv, dbzet, Node::null(), lemmas);
+      return false;
     }else{
       std::map< Node, Node >::iterator itsv = d_cache[a].d_search_val[tn].find( bvr );
       Node bad_val_bvr;
@@ -880,43 +869,43 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d,
         // do analysis of the evaluation  FIXME: does not work (evaluation is non-constant)
         quantifiers::EquivSygusInvarianceTest eset;
         eset.init(d_tds, tn, aconj, a, bvr);
+
         Trace("sygus-sb-mexp-debug") << "Minimize explanation for eval[" << d_tds->sygusToBuiltin( bad_val ) << "] = " << bvr << std::endl;
-        d_tds->getExplain()->getExplanationFor(
-            x, bad_val, exp, eset, bad_val_o, sz);
-        do_exclude = true;
-      }
-    }
-    if( do_exclude ){
-      Node lem = exp.size()==1 ? exp[0] : NodeManager::currentNM()->mkNode( kind::AND, exp );
-      lem = lem.negate();
-      /*  add min type depth to size : TODO?
-      Assert( d_term_to_anchor.find( n )!=d_term_to_anchor.end() );
-      TypeNode atype = d_term_to_anchor[n].getType();
-      if( atype!=tn ){
-        unsigned min_type_depth = d_tds->getMinTypeDepth( atype, tn );
-        if( min_type_depth>0 ){
-          Trace("sygus-sb-exc") << "  ........min type depth for " << ((DatatypeType)tn.toType()).getDatatype().getName() << " in ";
-          Trace("sygus-sb-exc") << ((DatatypeType)atype.toType()).getDatatype().getName() << " is " << min_type_depth << std::endl;
-          sz = sz + min_type_depth;
-        }
+        registerSymBreakLemmaForValue(a, bad_val, eset, bad_val_o, lemmas);
+        return false;
       }
-      */
-      Trace("sygus-sb-exc") << "  ........exc lemma is " << lem << ", size = " << sz << std::endl;
-      registerSymBreakLemma( tn, lem, sz, a, lemmas );
-      Trace("dt-sygus")
-          << "  ...excluded by dynamic symmetry breaking, based on " << n
-          << " == " << bvr << std::endl;
-      return false;
     }
   }
   return true;
 }
 
-
+void SygusSymBreakNew::registerSymBreakLemmaForValue(
+    Node a,
+    Node val,
+    quantifiers::SygusInvarianceTest& et,
+    Node valr,
+    std::vector<Node>& lemmas)
+{
+  TypeNode tn = val.getType();
+  Node x = getFreeVar(tn);
+  unsigned sz = d_tds->getSygusTermSize(val);
+  std::vector<Node> exp;
+  d_tds->getExplain()->getExplanationFor(x, val, exp, et, valr, sz);
+  Node lem =
+      exp.size() == 1 ? exp[0] : NodeManager::currentNM()->mkNode(AND, exp);
+  lem = lem.negate();
+  Trace("sygus-sb-exc") << "  ........exc lemma is " << lem << ", size = " << sz
+                        << std::endl;
+  registerSymBreakLemma(tn, lem, sz, a, lemmas);
+}
 
 void SygusSymBreakNew::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz, Node a, std::vector< Node >& lemmas ) {
   // lem holds for all terms of type tn, and is applicable to terms of size sz
-  Trace("sygus-sb-debug") << "  register sym break lemma : " << lem << ", size " << sz << std::endl;
+  Trace("sygus-sb-debug") << "  register sym break lemma : " << lem
+                          << std::endl;
+  Trace("sygus-sb-debug") << "     anchor : " << a << std::endl;
+  Trace("sygus-sb-debug") << "     type : " << tn << std::endl;
+  Trace("sygus-sb-debug") << "     size : " << sz << std::endl;
   Assert( !a.isNull() );
   d_cache[a].d_sb_lemmas[tn][sz].push_back( lem );
   TNode x = getFreeVar( tn );
@@ -928,7 +917,7 @@ void SygusSymBreakNew::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz
       for( unsigned k=0; k<itt->second.size(); k++ ){
         TNode t = itt->second[k];  
         if( !options::sygusSymBreakLazy() || d_active_terms.find( t )!=d_active_terms.end() ){
-          addSymBreakLemma( tn, lem, x, t, sz, d, lemmas );
+          addSymBreakLemma(lem, x, t, lemmas);
         }
       }
     }
@@ -953,14 +942,18 @@ void SygusSymBreakNew::addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, No
       if( (int)it->first<=max_sz ){
         for( unsigned k=0; k<it->second.size(); k++ ){
           Node lem = it->second[k];
-          addSymBreakLemma( tn, lem, x, t, it->first, d, lemmas );
+          addSymBreakLemma(lem, x, t, lemmas);
         }
       }
     }
   }
 }
 
-void SygusSymBreakNew::addSymBreakLemma( TypeNode tn, Node lem, TNode x, TNode n, unsigned lem_sz, unsigned n_depth, std::vector< Node >& lemmas ) {
+void SygusSymBreakNew::addSymBreakLemma(Node lem,
+                                        TNode x,
+                                        TNode n,
+                                        std::vector<Node>& lemmas)
+{
   Assert( !options::sygusSymBreakLazy() || d_active_terms.find( n )!=d_active_terms.end() );
   // apply lemma
   Node slem = lem.substitute( x, n );
@@ -1124,7 +1117,7 @@ void SygusSymBreakNew::incrementCurrentSearchSize( Node m, std::vector< Node >&
               if( !options::sygusSymBreakLazy() || d_active_terms.find( t )!=d_active_terms.end() ){
                 for( unsigned j=0; j<it->second.size(); j++ ){
                   Node lem = it->second[j];
-                  addSymBreakLemma( tn, lem, x, t, sz, new_depth, lemmas );
+                  addSymBreakLemma(lem, x, t, lemmas);
                 }
               }
             }
@@ -1137,6 +1130,30 @@ void SygusSymBreakNew::incrementCurrentSearchSize( Node m, std::vector< Node >&
 
 void SygusSymBreakNew::check( std::vector< Node >& lemmas ) {
   Trace("sygus-sb") << "SygusSymBreakNew::check" << std::endl;
+
+  // check for externally registered symmetry breaking lemmas
+  std::vector<Node> anchors;
+  if (d_tds->hasSymBreakLemmas(anchors))
+  {
+    for (const Node& a : anchors)
+    {
+      std::vector<Node> sbl;
+      d_tds->getSymBreakLemmas(a, sbl);
+      for (const Node& lem : sbl)
+      {
+        TypeNode tn = d_tds->getTypeForSymBreakLemma(lem);
+        unsigned sz = d_tds->getSizeForSymBreakLemma(lem);
+        registerSymBreakLemma(tn, lem, sz, a, lemmas);
+      }
+    }
+    d_tds->clearSymBreakLemmas();
+    if (!lemmas.empty())
+    {
+      return;
+    }
+  }
+
+  // register search values, add symmetry breaking lemmas if applicable
   for( std::map< Node, bool >::iterator it = d_register_st.begin(); it != d_register_st.end(); ++it ){
     if( it->second ){
       Node prog = it->first;
index fa3918270c28d5fb388081bd3e3a0ecea62ab791..cb7729658286e89e8f7562c2e701c6d185a549f5 100644 (file)
@@ -30,6 +30,7 @@
 #include "expr/datatype.h"
 #include "expr/node.h"
 #include "theory/quantifiers/sygus/ce_guided_conjecture.h"
+#include "theory/quantifiers/sygus/sygus_explain.h"
 #include "theory/quantifiers/sygus_sampler.h"
 #include "theory/quantifiers/term_database.h"
 
@@ -41,13 +42,14 @@ class TheoryDatatypes;
 
 class SygusSymBreakNew
 {
-private:
-  TheoryDatatypes * d_td;
-  quantifiers::TermDbSygus * d_tds;
   typedef context::CDHashMap< Node, int, NodeHashFunction > IntMap;
   typedef context::CDHashMap< Node, Node, NodeHashFunction > NodeMap;
   typedef context::CDHashMap< Node, bool, NodeHashFunction > BoolMap;
   typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
+
+ private:
+  TheoryDatatypes* d_td;
+  quantifiers::TermDbSygus* d_tds;
   IntMap d_testers;
   IntMap d_is_const;
   NodeMap d_testers_exp;
@@ -86,7 +88,10 @@ private:
    *
    */
   std::map< Node, bool > d_is_top_level;
-  void registerTerm( Node n, std::vector< Node >& lemmas );
+  /**
+   * Returns true if the selector chain n is top-level based on the above
+   * definition, when tn is the type of n.
+   */
   bool computeTopLevel( TypeNode tn, Node n );
 private:
   //list of all terms encountered in search at depth
@@ -117,7 +122,7 @@ private:
     /** For each term, whether this cache has processed that term */
     std::map< Node, bool > d_search_val_proc;
   };
-  // anchor -> cache
+  /** An instance of the above cache, for each anchor */
   std::map< Node, SearchCache > d_cache;
   /** a sygus sampler object for each (anchor, sygus type) pair
    *
@@ -125,21 +130,147 @@ private:
    * the rewriter.
    */
   std::map<Node, std::map<TypeNode, quantifiers::SygusSampler>> d_sampler;
-  Node d_null;
+  /** Assert tester internal
+   *
+   * This function is called when the tester with index tindex is asserted for
+   * n, exp is the tester predicate. For example, for grammar:
+   *   A -> A+A | x | 1 | 0
+   * when is_+( d ) is asserted,
+   * assertTesterInternal(0, s( d ), is_+( s( d ) ),...) is called. This
+   * function may add lemmas to lemmas, which are sent out on the output
+   * channel of datatypes by the caller.
+   *
+   * These lemmas are of various forms, including:
+   * (1) dynamic symmetry breaking clauses for subterms of n (those added to
+   * lemmas on calls to addSymBreakLemmasFor, see function below),
+   * (2) static symmetry breaking clauses for subterms of n (those added to
+   * lemmas on getSimpleSymBreakPred, see function below),
+   * (3) conjecture-specific symmetry breaking lemmas, see
+   * CegConjecture::getSymmetryBreakingPredicate,
+   * (4) fairness conflicts if sygusFair() is SYGUS_FAIR_DIRECT, e.g.:
+   *    size( d ) <= 1 V ~is-C1( d ) V ~is-C2( d.1 )
+   * where C1 and C2 are non-nullary constructors.
+   */
   void assertTesterInternal( int tindex, TNode n, Node exp, std::vector< Node >& lemmas );
-  // register search term
+  /**
+   * This function is called when term n is registered to the theory of
+   * datatypes. It makes the appropriate call to registerSearchTerm below,
+   * if applicable.
+   */
+  void registerTerm(Node n, std::vector<Node>& lemmas);
+  /** Register search term
+   *
+   * This function is called when selector chain S_1( ... S_m( n ) ... ) is
+   * registered to the theory of datatypes, where tn is the type of n,
+   * d indicates the depth of n (the sum of weights of the selectors S_1...S_m),
+   * and topLevel is whether n is a top-level term (see d_is_top_level).
+   *
+   * The purpose of this function is to notify this class that symmetry breaking
+   * lemmas should be instantiated for n. Any symmetry breaking lemmas that
+   * are active for n (see description of addSymBreakLemmasFor) are added to
+   * lemmas in this call.
+   */
   void registerSearchTerm( TypeNode tn, unsigned d, Node n, bool topLevel, std::vector< Node >& lemmas );
+  /** Register search value
+   *
+   * This function is called when a selector chain n has been assigned a model
+   * value nv. This function calls itself recursively so that extensions of the
+   * selector chain n are registered with all the subterms of nv. For example,
+   * if we call this function with:
+   *   n = x, nv = +( 1(), x() )
+   * we make recursive calls with:
+   *   n = x.1, nv = 1() and n = x.2, nv = x()
+   *
+   * a : the anchor of n,
+   * d : the depth of n.
+   *
+   * This function determines if the value nv is equivalent via rewriting to
+   * any previously registered search values for anchor a. If so, we construct
+   * a symmetry breaking lemma template and register it in d_cache[a]. For
+   * example, for grammar:
+   *   A -> A+A | x | 1 | 0
+   * Registering search value d -> x followed by d -> +( x, 0 ) results in the
+   * construction of the symmetry breaking lemma template:
+   *   ~is_+( z ) V ~is_x( z.1 ) V ~is_0( z.2 )
+   * which is stored in d_cache[a].d_sb_lemmas. This lemma is instantiated with
+   * z -> t for all terms t of appropriate depth, including d.
+   * This function strengthens blocking clauses using generalization techniques
+   * described in Reynolds et al SYNT 2017.
+   */
   bool registerSearchValue( Node a, Node n, Node nv, unsigned d, std::vector< Node >& lemmas );
-  void registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz, Node e, std::vector< Node >& lemmas );
-  void addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, Node e, std::vector< Node >& lemmas );
+  /** Register symmetry breaking lemma
+   *
+   * This function adds the symmetry breaking lemma template lem for terms of
+   * type tn with anchor a. This is added to d_cache[a].d_sb_lemmas. Notice that
+   * we use lem as a template with free variable x, e.g. our template is:
+   *   (lambda ((x tn)) lem)
+   * where x = getFreeVar( tn ). For all search terms t of the appropriate
+   * depth,
+   * we add the lemma lem{ x -> t } to lemmas.
+   *
+   * The argument sz indicates the size of terms that the lemma applies to, e.g.
+   *   ~is_+( z ) has size 1
+   *   ~is_+( z ) V ~is_x( z.1 ) V ~is_0( z.2 ) has size 1
+   *   ~is_+( z ) V ~is_+( z.1 ) has size 2
+   * This is equivalent to sum of weights of constructors corresponding to each
+   * tester, e.g. above + has weight 1, and x and 0 have weight 0.
+   */
+  void registerSymBreakLemma(
+      TypeNode tn, Node lem, unsigned sz, Node a, std::vector<Node>& lemmas);
+  /** Register symmetry breaking lemma for value
+   *
+   * This function adds a symmetry breaking lemma template for selector chains
+   * with anchor a, that effectively states that val should never be a subterm
+   * of any value for a.
+   *
+   * et : an "invariance test" (see sygus/sygus_invariance.h) which states a
+   * criterion that val meets, which is the reason for its exclusion. This is
+   * used for generalizing the symmetry breaking lemma template.
+   * valr : if non-null, this states a value that should *not* be excluded by
+   * the symmetry breaking lemma template, which is a restriction to the above
+   * generalization.
+   *
+   * This function may add instances of the symmetry breaking template for
+   * existing search terms, which are added to lemmas.
+   */
+  void registerSymBreakLemmaForValue(Node a,
+                                     Node val,
+                                     quantifiers::SygusInvarianceTest& et,
+                                     Node valr,
+                                     std::vector<Node>& lemmas);
+  /** Add symmetry breaking lemmas for term
+   *
+   * Adds all active symmetry breaking lemmas for selector chain t to lemmas. A
+   * symmetry breaking lemma L is active for t based on three factors:
+   * (1) the current search size sz(a) for its anchor a,
+   * (2) the depth d of term t (see d_term_to_depth),
+   * (3) the size sz(L) of the symmetry breaking lemma L.
+   * In particular, L is active if sz(L) <= sz(a) - d. In other words, a
+   * symmetry breaking lemma is active if it is intended to block terms of
+   * size sz(L), and the maximum size that t can take in the current search,
+   * sz(a)-d, is greater than or equal to this value.
+   *
+   * tn : the type of term t,
+   * a : the anchor of term t,
+   * d : the depth of term t.
+   */
+  void addSymBreakLemmasFor(
+      TypeNode tn, Node t, unsigned d, Node a, std::vector<Node>& lemmas);
+  /** calls the above function where a is the anchor t */
   void addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, std::vector< Node >& lemmas );
-  void addSymBreakLemma( TypeNode tn, Node lem, TNode x, TNode n, unsigned lem_sz, unsigned n_depth, std::vector< Node >& lemmas );
-private:
+  /** add symmetry breaking lemma
+   *
+   * This adds the lemma R => lem{ x -> n } to lemmas, where R is a "relevancy
+   * condition" that states which contexts n is relevant in (contexts in which
+   * the selector chain n is specified).
+   */
+  void addSymBreakLemma(Node lem, TNode x, TNode n, std::vector<Node>& lemmas);
+
+ private:
   std::map< Node, Node > d_rlv_cond;
   Node getRelevancyCondition( Node n );
 private:
   std::map< TypeNode, std::map< int, std::map< unsigned, Node > > > d_simple_sb_pred;
-  std::map< TypeNode, Node > d_free_var;
   // user-context dependent if sygus-incremental
   std::map< Node, unsigned > d_simple_proc;
   //get simple symmetry breaking predicate
index 1dd4dcbebc66e7398a93f31827fe2640c39da6e5..2273db5ea44e39754bfbd536332be793e2205495 100644 (file)
@@ -547,9 +547,8 @@ void CegConjecture::printAndContinueStream()
     {
       sol = d_cinfo[cprog].d_inst.back();
       // add to explanation of exclusion
-      d_qe->getTermDatabaseSygus()
-          ->getExplain()
-          ->getExplanationForConstantEquality(cprog, sol, exp);
+      d_qe->getTermDatabaseSygus()->getExplain()->getExplanationForEquality(
+          cprog, sol, exp);
     }
   }
   Assert(!exp.empty());
@@ -612,6 +611,8 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
         if (eq_sol != sol)
         {
           ++(cei->d_statistics.d_candidate_rewrites);
+          // if eq_sol is null, then we have an uninteresting candidate rewrite,
+          // e.g. one that is alpha-equivalent to another.
           if (!eq_sol.isNull())
           {
             // The analog of terms sol and eq_sol are equivalent under sample
index aafaa07e1eca22663d27fa7c332e0e9f134d971e..f76edb1c3ce0cee43fd6c61a9dc7ab6d0fb888d1 100644 (file)
@@ -110,19 +110,25 @@ Node TermRecBuild::build(unsigned d)
   return NodeManager::currentNM()->mkNode(d_kind[d], children);
 }
 
-void SygusExplain::getExplanationForConstantEquality(Node n,
-                                                     Node vn,
-                                                     std::vector<Node>& exp)
+void SygusExplain::getExplanationForEquality(Node n,
+                                             Node vn,
+                                             std::vector<Node>& exp)
 {
   std::map<unsigned, bool> cexc;
-  getExplanationForConstantEquality(n, vn, exp, cexc);
+  getExplanationForEquality(n, vn, exp, cexc);
 }
 
-void SygusExplain::getExplanationForConstantEquality(
-    Node n, Node vn, std::vector<Node>& exp, std::map<unsigned, bool>& cexc)
+void SygusExplain::getExplanationForEquality(Node n,
+                                             Node vn,
+                                             std::vector<Node>& exp,
+                                             std::map<unsigned, bool>& cexc)
 {
-  Assert(vn.getKind() == kind::APPLY_CONSTRUCTOR);
   Assert(n.getType() == vn.getType());
+  if (n == vn)
+  {
+    return;
+  }
+  Assert(vn.getKind() == kind::APPLY_CONSTRUCTOR);
   TypeNode tn = n.getType();
   Assert(tn.isDatatype());
   const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype();
@@ -137,22 +143,23 @@ void SygusExplain::getExplanationForConstantEquality(
           kind::APPLY_SELECTOR_TOTAL,
           Node::fromExpr(dt[i].getSelectorInternal(tn.toType(), j)),
           n);
-      getExplanationForConstantEquality(sel, vn[j], exp);
+      getExplanationForEquality(sel, vn[j], exp);
     }
   }
 }
 
-Node SygusExplain::getExplanationForConstantEquality(Node n, Node vn)
+Node SygusExplain::getExplanationForEquality(Node n, Node vn)
 {
   std::map<unsigned, bool> cexc;
-  return getExplanationForConstantEquality(n, vn, cexc);
+  return getExplanationForEquality(n, vn, cexc);
 }
 
-Node SygusExplain::getExplanationForConstantEquality(
-    Node n, Node vn, std::map<unsigned, bool>& cexc)
+Node SygusExplain::getExplanationForEquality(Node n,
+                                             Node vn,
+                                             std::map<unsigned, bool>& cexc)
 {
   std::vector<Node> exp;
-  getExplanationForConstantEquality(n, vn, exp, cexc);
+  getExplanationForEquality(n, vn, exp, cexc);
   Assert(!exp.empty());
   return exp.size() == 1 ? exp[0]
                          : NodeManager::currentNM()->mkNode(kind::AND, exp);
@@ -250,7 +257,7 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb,
       // if excluded, we may need to add the explanation for this
       if (vnr_exp.isNull() && !vnr_c.isNull())
       {
-        vnr_exp = getExplanationForConstantEquality(sel, vnr[i]);
+        vnr_exp = getExplanationForEquality(sel, vnr[i]);
       }
     }
   }
@@ -264,7 +271,7 @@ void SygusExplain::getExplanationFor(Node n,
                                      unsigned& sz)
 {
   // naive :
-  // return getExplanationForConstantEquality( n, vn, exp );
+  // return getExplanationForEquality( n, vn, exp );
 
   // set up the recursion object
   std::map<TypeNode, int> var_count;
index ad26f29e49f854018b8c4dda62f0b67440c008e7..818f514384c8b78a07aaa211beedb504d86a1a3a 100644 (file)
@@ -100,7 +100,7 @@ class TermRecBuild
  * (datatype) sygus term n is:
  *  (if (gt x 0) 0 0)
  * where if, gt, x, 0 are datatype constructors.
- * The explanation returned by getExplanationForConstantEquality
+ * The explanation returned by getExplanationForEquality
  * below for n and the above term is:
  *   { ((_ is if) n), ((_ is geq) n.0),
  *     ((_ is x) n.0.0), ((_ is 0) n.0.1),
@@ -142,20 +142,19 @@ class SygusExplain
  public:
   SygusExplain(TermDbSygus* tdb) : d_tdb(tdb) {}
   ~SygusExplain() {}
-  /** get explanation for constant equality
+  /** get explanation for equality
    *
    * This function constructs an explanation, stored in exp, such that:
    * - All formulas in exp are of the form ((_ is C) ns), where ns
    *   is a chain of selectors applied to n, and
    * - exp => ( n = vn )
    */
-  void getExplanationForConstantEquality(Node n,
-                                         Node vn,
-                                         std::vector<Node>& exp);
+  void getExplanationForEquality(Node n, Node vn, std::vector<Node>& exp);
   /** returns the conjunction of exp computed in the above function */
-  Node getExplanationForConstantEquality(Node n, Node vn);
+  Node getExplanationForEquality(Node n, Node vn);
 
-  /** get explanation for constant equality
+  /** get explanation for equality
+   *
    * This is identical to the above function except that we
    * take an additional argument cexc, which says which
    * children of vn should be excluded from the explanation.
@@ -165,14 +164,14 @@ class SygusExplain
    *   { ((_ is plus) n), ((_ is y) n.1) }
    * where notice that the 0^th argument of vn is excluded.
    */
-  void getExplanationForConstantEquality(Node n,
-                                         Node vn,
-                                         std::vector<Node>& exp,
-                                         std::map<unsigned, bool>& cexc);
+  void getExplanationForEquality(Node n,
+                                 Node vn,
+                                 std::vector<Node>& exp,
+                                 std::map<unsigned, bool>& cexc);
   /** returns the conjunction of exp computed in the above function */
-  Node getExplanationForConstantEquality(Node n,
-                                         Node vn,
-                                         std::map<unsigned, bool>& cexc);
+  Node getExplanationForEquality(Node n,
+                                 Node vn,
+                                 std::map<unsigned, bool>& cexc);
 
   /** get explanation for
    *
index 36e8838483243427258b36de271051bf32b0ee24..1c61544e14bc4df59db4181e998af3f71f71d1bc 100644 (file)
@@ -1303,7 +1303,7 @@ void CegConjecturePbe::addEnumeratedValue( Node x, Node v, std::vector< Node >&
   if (exp_exc.isNull())
   {
     // if we did not already explain why this should be excluded, use default
-    exp_exc = d_tds->getExplain()->getExplanationForConstantEquality(x, v);
+    exp_exc = d_tds->getExplain()->getExplanationForEquality(x, v);
   }
   Node exlem =
       NodeManager::currentNM()->mkNode(OR, g.negate(), exp_exc.negate());
index b12a23c831b2d5399cd43827f0c7bae7298ec4df..e8bdf2083b0ba2b96e712aaae138d24db2e752d3 100644 (file)
@@ -733,6 +733,59 @@ void TermDbSygus::getEnumerators(std::vector<Node>& mts)
   }
 }
 
+void TermDbSygus::registerSymBreakLemma(Node e,
+                                        Node lem,
+                                        TypeNode tn,
+                                        unsigned sz)
+{
+  d_enum_to_sb_lemmas[e].push_back(lem);
+  d_sb_lemma_to_type[lem] = tn;
+  d_sb_lemma_to_size[lem] = sz;
+}
+
+bool TermDbSygus::hasSymBreakLemmas(std::vector<Node>& enums) const
+{
+  if (!d_enum_to_sb_lemmas.empty())
+  {
+    for (std::pair<const Node, std::vector<Node> > sb : d_enum_to_sb_lemmas)
+    {
+      enums.push_back(sb.first);
+    }
+    return true;
+  }
+  return false;
+}
+
+void TermDbSygus::getSymBreakLemmas(Node e, std::vector<Node>& lemmas) const
+{
+  std::map<Node, std::vector<Node> >::const_iterator itsb =
+      d_enum_to_sb_lemmas.find(e);
+  if (itsb != d_enum_to_sb_lemmas.end())
+  {
+    lemmas.insert(lemmas.end(), itsb->second.begin(), itsb->second.end());
+  }
+}
+
+TypeNode TermDbSygus::getTypeForSymBreakLemma(Node lem) const
+{
+  std::map<Node, TypeNode>::const_iterator it = d_sb_lemma_to_type.find(lem);
+  Assert(it != d_sb_lemma_to_type.end());
+  return it->second;
+}
+unsigned TermDbSygus::getSizeForSymBreakLemma(Node lem) const
+{
+  std::map<Node, unsigned>::const_iterator it = d_sb_lemma_to_size.find(lem);
+  Assert(it != d_sb_lemma_to_size.end());
+  return it->second;
+}
+
+void TermDbSygus::clearSymBreakLemmas()
+{
+  d_enum_to_sb_lemmas.clear();
+  d_sb_lemma_to_type.clear();
+  d_sb_lemma_to_size.clear();
+}
+
 bool TermDbSygus::isRegistered( TypeNode tn ) {
   return d_register.find( tn )!=d_register.end();
 }
@@ -1202,7 +1255,7 @@ void TermDbSygus::registerModelValue( Node a, Node v, std::vector< Node >& terms
         unsigned start = d_node_mv_args_proc[n][vn];
         // get explanation in terms of testers
         std::vector< Node > antec_exp;
-        d_syexp->getExplanationForConstantEquality(n, vn, antec_exp);
+        d_syexp->getExplanationForEquality(n, vn, antec_exp);
         Node antec = antec_exp.size()==1 ? antec_exp[0] : NodeManager::currentNM()->mkNode( kind::AND, antec_exp );
         //Node antec = n.eqNode( vn );
         TypeNode tn = n.getType();
index e796a3adcf158d146bfae1c4a6e0bbadef7d06fb..7ef9e6151b21087ca89fdf351fe9ec3c0768ac9c 100644 (file)
@@ -40,21 +40,31 @@ class TermDbSygus {
   std::string identify() const { return "TermDbSygus"; }
   /** register the sygus type */
   void registerSygusType(TypeNode tn);
-  /** register a variable e that we will do enumerative search on
-   * conj is the conjecture that the enumeration of e is for.
-   * f is the synth-fun that the enumeration of e is for.
-   * mkActiveGuard is whether we want to make an active guard for e 
+
+  //------------------------------utilities
+  /** get the explanation utility */
+  SygusExplain* getExplain() { return d_syexp.get(); }
+  /** get the extended rewrite utility */
+  ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); }
+  //------------------------------end utilities
+
+  //------------------------------enumerators
+  /**
+   * Register a variable e that we will do enumerative search on.
+   * conj : the conjecture that the enumeration of e is for.
+   * f : the synth-fun that the enumeration of e is for.
+   * mkActiveGuard : whether we want to make an active guard for e
    * (see d_enum_to_active_guard).
    *
-   * Notice that enumerator e may not be equivalent
-   * to f in synthesis-through-unification approaches
-   * (e.g. decision tree construction for PBE synthesis).
+   * Notice that enumerator e may not be one-to-one with f in
+   * synthesis-through-unification approaches (e.g. decision tree construction
+   * for PBE synthesis).
    */
   void registerEnumerator(Node e,
                           Node f,
                           CegConjecture* conj,
                           bool mkActiveGuard = false);
-  /** is e an enumerator? */
+  /** is e an enumerator registered with this class? */
   bool isEnumerator(Node e) const;
   /** return the conjecture e is associated with */
   CegConjecture* getConjectureForEnumerator(Node e);
@@ -64,10 +74,36 @@ class TermDbSygus {
   Node getActiveGuardForEnumerator(Node e);
   /** get all registered enumerators */
   void getEnumerators(std::vector<Node>& mts);
-  /** get the explanation utility */
-  SygusExplain* getExplain() { return d_syexp.get(); }
-  /** get the extended rewrite utility */
-  ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); }
+  /** Register symmetry breaking lemma
+   *
+   * This function registers lem as a symmetry breaking lemma template for
+   * subterms of enumerator e. For more information on symmetry breaking
+   * lemma templates, see datatypes/datatypes_sygus.h.
+   *
+   * tn : the (sygus datatype) type that lem applies to, i.e. the
+   * type of terms that lem blocks models for,
+   * sz : the minimum size of terms that the lem blocks.
+   *
+   * Notice that the symmetry breaking lemma template should be relative to x,
+   * where x is returned by the call to getFreeVar( tn, 0 ) in this class.
+   */
+  void registerSymBreakLemma(Node e, Node lem, TypeNode tn, unsigned sz);
+  /** Has symmetry breaking lemmas been added for any enumerator? */
+  bool hasSymBreakLemmas(std::vector<Node>& enums) const;
+  /** Get symmetry breaking lemmas
+   *
+   * Returns the set of symmetry breaking lemmas that have been registered
+   * for enumerator e. It adds these to lemmas.
+   */
+  void getSymBreakLemmas(Node e, std::vector<Node>& lemmas) const;
+  /** Get the type of term symmetry breaking lemma lem applies to */
+  TypeNode getTypeForSymBreakLemma(Node lem) const;
+  /** Get the minimum size of terms symmetry breaking lemma lem applies to */
+  unsigned getSizeForSymBreakLemma(Node lem) const;
+  /** Clear information about symmetry breaking lemmas */
+  void clearSymBreakLemmas();
+  //------------------------------end enumerators
+
   //-----------------------------conversion from sygus to builtin
   /** get free variable
    *
@@ -121,10 +157,15 @@ class TermDbSygus {
  private:
   /** reference to the quantifiers engine */
   QuantifiersEngine* d_quantEngine;
+
+  //------------------------------utilities
   /** sygus explanation */
   std::unique_ptr<SygusExplain> d_syexp;
   /** sygus explanation */
   std::unique_ptr<ExtendedRewriter> d_ext_rw;
+  //------------------------------end utilities
+
+  //------------------------------enumerators
   /** mapping from enumerator terms to the conjecture they are associated with
    */
   std::map<Node, CegConjecture*> d_enum_to_conjecture;
@@ -137,6 +178,13 @@ class TermDbSygus {
    *   if G is true, then there are more values of e to enumerate".
    */
   std::map<Node, Node> d_enum_to_active_guard;
+  /** mapping from enumerators to symmetry breaking clauses for them */
+  std::map<Node, std::vector<Node> > d_enum_to_sb_lemmas;
+  /** mapping from symmetry breaking lemmas to type */
+  std::map<Node, TypeNode> d_sb_lemma_to_type;
+  /** mapping from symmetry breaking lemmas to size */
+  std::map<Node, unsigned> d_sb_lemma_to_size;
+  //------------------------------end enumerators
 
   //-----------------------------conversion from sygus to builtin
   /** cache for sygusToBuiltin */