Minor refactor datatypes sygus (#1673)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 20 Mar 2018 22:32:43 +0000 (17:32 -0500)
committerMathias Preiner <mathias.preiner@gmail.com>
Tue, 20 Mar 2018 22:32:43 +0000 (15:32 -0700)
src/theory/datatypes/datatypes_sygus.cpp
src/theory/datatypes/datatypes_sygus.h
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h

index 1779ab27b4c089d57cb246847db7ac8fa211e0c1..3c90bc448cab3f3c88bc187d732748b3d15b94f7 100644 (file)
@@ -222,10 +222,10 @@ void SygusSymBreakNew::registerTerm( Node n, std::vector< Node >& lemmas ) {
     bool success = false;
     if( n.getKind()==kind::APPLY_SELECTOR_TOTAL ){
       registerTerm( n[0], lemmas );
-      std::map< Node, Node >::iterator it = d_term_to_anchor.find( n[0] );
+      std::unordered_map<Node, Node, NodeHashFunction>::iterator it =
+          d_term_to_anchor.find(n[0]);
       if( it!=d_term_to_anchor.end() ) {
         d_term_to_anchor[n] = it->second;
-        d_term_to_anchor_conj[n] = d_term_to_anchor_conj[n[0]];
         unsigned sel_weight =
             d_tds->getSelectorWeight(n[0].getType(), n.getOperator());
         d = d_term_to_depth[n[0]] + sel_weight;
@@ -236,9 +236,9 @@ void SygusSymBreakNew::registerTerm( Node n, std::vector< Node >& lemmas ) {
       registerSizeTerm( n, lemmas );
       if( d_register_st[n] ){
         d_term_to_anchor[n] = n;
-        d_term_to_anchor_conj[n] = d_tds->getConjectureForEnumerator(n);
+        d_anchor_to_conj[n] = d_tds->getConjectureForEnumerator(n);
         // this assertion fails if we have a sygus term in the search that is unmeasured
-        Assert(d_term_to_anchor_conj[n] != NULL);
+        Assert(d_anchor_to_conj[n] != NULL);
         d = 0;
         is_top_level = true;
         success = true;
@@ -354,8 +354,8 @@ void SygusSymBreakNew::assertTesterInternal( int tindex, TNode n, Node exp, std:
       }
       // static conjecture-dependent symmetry breaking
       std::map<Node, quantifiers::CegConjecture*>::iterator itc =
-          d_term_to_anchor_conj.find(n);
-      if (itc != d_term_to_anchor_conj.end())
+          d_anchor_to_conj.find(a);
+      if (itc != d_anchor_to_conj.end())
       {
         quantifiers::CegConjecture* conj = itc->second;
         Assert(conj != NULL);
@@ -691,7 +691,8 @@ unsigned SygusSymBreakNew::processSelectorChain( Node n, std::map< TypeNode, Nod
 
 void SygusSymBreakNew::registerSearchTerm( TypeNode tn, unsigned d, Node n, bool topLevel, std::vector< Node >& lemmas ) {
   //register this term
-  std::map< Node, Node >::iterator ita = d_term_to_anchor.find( n );
+  std::unordered_map<Node, Node, NodeHashFunction>::iterator ita =
+      d_term_to_anchor.find(n);
   Assert( ita != d_term_to_anchor.end() );
   Node a = ita->second;
   Assert( !a.isNull() );
@@ -722,10 +723,10 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d,
   Trace("sygus-sb-debug2") << "Registering search value " << n << " -> " << nv << std::endl;
   // must do this for all nodes, regardless of top-level
   if( d_cache[a].d_search_val_proc.find( nv )==d_cache[a].d_search_val_proc.end() ){
-    d_cache[a].d_search_val_proc[nv] = true;
+    d_cache[a].d_search_val_proc.insert(nv);
     // get the root (for PBE symmetry breaking)
-    Assert(d_term_to_anchor_conj.find(a) != d_term_to_anchor_conj.end());
-    quantifiers::CegConjecture* aconj = d_term_to_anchor_conj[a];
+    Assert(d_anchor_to_conj.find(a) != d_anchor_to_conj.end());
+    quantifiers::CegConjecture* aconj = d_anchor_to_conj[a];
     Assert(aconj != NULL);
     Trace("sygus-sb-debug") << "  ...register search value " << nv << ", type=" << tn << std::endl;
     Node bv = d_tds->sygusToBuiltin( nv, tn );
@@ -740,7 +741,8 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d,
       registerSymBreakLemmaForValue(a, nv, dbzet, Node::null(), lemmas);
       return false;
     }else{
-      std::map< Node, Node >::iterator itsv = d_cache[a].d_search_val[tn].find( bvr );
+      std::unordered_map<Node, Node, NodeHashFunction>::iterator itsv =
+          d_cache[a].d_search_val[tn].find(bvr);
       Node bad_val_bvr;
       bool by_examples = false;
       if( itsv==d_cache[a].d_search_val[tn].end() ){
@@ -787,20 +789,9 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d,
               d_tds, nv, options::sygusSamples(), false);
           its = d_sampler[a].find(tn);
         }
-        Node bvr_sample_ret;
-        std::map<Node, Node>::iterator itsv =
-            d_cache[a].d_search_val_sample[tn].find(bvr);
-        if (itsv == d_cache[a].d_search_val_sample[tn].end())
-        {
-          // initialize the sampler for the rewritten form of this node
-          bvr_sample_ret = its->second.registerTerm(bvr);
-          d_cache[a].d_search_val_sample[tn][bvr] = bvr_sample_ret;
-        }
-        else
-        {
-          bvr_sample_ret = itsv->second;
-        }
 
+        // register the rewritten node with the sampler
+        Node bvr_sample_ret = its->second.registerTerm(bvr);
         // register the current node with the sampler
         Node sample_ret = its->second.registerTerm(bv);
 
@@ -1072,7 +1063,8 @@ void SygusSymBreakNew::notifySearchSize( Node m, unsigned s, Node exp, std::vect
 
 unsigned SygusSymBreakNew::getSearchSizeFor( Node n ) {
   Trace("sygus-sb-debug2") << "get search size for term : " << n << std::endl;
-  std::map< Node, Node >::iterator ita = d_term_to_anchor.find( n );
+  std::unordered_map<Node, Node, NodeHashFunction>::iterator ita =
+      d_term_to_anchor.find(n);
   Assert( ita != d_term_to_anchor.end() );
   return getSearchSizeForAnchor( ita->second );
 }
index cb7729658286e89e8f7562c2e701c6d185a549f5..2936c1561b8bed26c7474cf40687535c808b8e54 100644 (file)
@@ -47,8 +47,22 @@ class SygusSymBreakNew
   typedef context::CDHashMap< Node, bool, NodeHashFunction > BoolMap;
   typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
 
+ public:
+  SygusSymBreakNew(TheoryDatatypes* td,
+                   quantifiers::TermDbSygus* tds,
+                   context::Context* c);
+  ~SygusSymBreakNew();
+  /** add tester */
+  void assertTester(int tindex, TNode n, Node exp, std::vector<Node>& lemmas);
+  void assertFact(Node n, bool polarity, std::vector<Node>& lemmas);
+  void preRegisterTerm(TNode n, std::vector<Node>& lemmas);
+  void check(std::vector<Node>& lemmas);
+  Node getNextDecisionRequest(unsigned& priority, std::vector<Node>& lemmas);
+
  private:
+  /** Pointer to the datatype theory that owns this class. */
   TheoryDatatypes* d_td;
+  /** Pointer to the sygus term database */
   quantifiers::TermDbSygus* d_tds;
   IntMap d_testers;
   IntMap d_is_const;
@@ -56,18 +70,15 @@ class SygusSymBreakNew
   NodeSet d_active_terms;
   IntMap d_currTermSize;
   Node d_zero;
-
- private:
   /**
    * Map from terms (selector chains) to their anchors. The anchor of a
    * selector chain S1( ... Sn( x ) ... ) is x.
    */
-  std::map< Node, Node > d_term_to_anchor;
+  std::unordered_map<Node, Node, NodeHashFunction> d_term_to_anchor;
   /**
-   * Map from terms (selector chains) to the conjecture that their anchor is
-   * associated with.
+   * Map from anchors to the conjecture they are associated with.
    */
-  std::map<Node, quantifiers::CegConjecture*> d_term_to_anchor_conj;
+  std::map<Node, quantifiers::CegConjecture*> d_anchor_to_conj;
   /**
    * Map from terms (selector chains) to their depth. The depth of a selector
    * chain S1( ... Sn( x ) ... ) is:
@@ -75,7 +86,7 @@ class SygusSymBreakNew
    * where weight is the selector weight of Si
    * (see SygusTermDatabase::getSelectorWeight).
    */
-  std::map< Node, unsigned > d_term_to_depth;
+  std::unordered_map<Node, unsigned, NodeHashFunction> d_term_to_depth;
   /**
    * Map from terms (selector chains) to whether they are the topmost term
    * of their type. For example, if:
@@ -87,18 +98,24 @@ class SygusSymBreakNew
    * whereas S2( S1( x ) ) and S3( S2( S1( x ) ) ) are not.
    *
    */
-  std::map< Node, bool > d_is_top_level;
+  std::unordered_map<Node, bool, NodeHashFunction> d_is_top_level;
   /**
    * Returns true if the selector chain n is top-level based on the above
    * definition, when tn is the type of n.
    */
   bool computeTopLevel( TypeNode tn, Node n );
 private:
-  //list of all terms encountered in search at depth
-  class SearchCache {
+ /** This caches all information regarding symmetry breaking for an anchor. */
+ class SearchCache
+ {
   public:
     SearchCache(){}
+    /**
+     * A cache of all search terms for (types, sizes). See registerSearchTerm
+     * for definition of search terms.
+     */
     std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_search_terms;
+    /** A cache of all symmetry breaking lemma templates for (types, sizes). */
     std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_sb_lemmas;
     /** search value
      *
@@ -107,20 +124,13 @@ private:
      * term. The range of this map can be updated if we later encounter a sygus
      * term that also rewrites to the builtin value but has a smaller term size.
      */
-    std::map< TypeNode, std::map< Node, Node > > d_search_val;
+    std::map<TypeNode, std::unordered_map<Node, Node, NodeHashFunction>>
+        d_search_val;
     /** the size of terms in the range of d_search val. */
-    std::map< TypeNode, std::map< Node, unsigned > > d_search_val_sz;
-    /** search value sample
-     *
-     * This is used for the sygusRewVerify() option. For each sygus term t
-     * of type tn with anchor a that we register with this cache, we set:
-     *   d_search_val_sample[tn][r] = r'
-     * where r is the rewritten form of the builtin equivalent of t, and r'
-     * is the term returned by d_sampler[a][tn].registerTerm( r ).
-     */
-    std::map<TypeNode, std::map<Node, Node>> d_search_val_sample;
+    std::map<TypeNode, std::unordered_map<Node, unsigned, NodeHashFunction>>
+        d_search_val_sz;
     /** For each term, whether this cache has processed that term */
-    std::map< Node, bool > d_search_val_proc;
+    std::unordered_set<Node, NodeHashFunction> d_search_val_proc;
   };
   /** An instance of the above cache, for each anchor */
   std::map< Node, SearchCache > d_cache;
@@ -158,12 +168,15 @@ private:
    * if applicable.
    */
   void registerTerm(Node n, std::vector<Node>& lemmas);
+
+  //------------------------dynamic symmetry breaking
   /** Register search term
    *
-   * This function is called when selector chain S_1( ... S_m( n ) ... ) is
-   * registered to the theory of datatypes, where tn is the type of n,
-   * d indicates the depth of n (the sum of weights of the selectors S_1...S_m),
-   * and topLevel is whether n is a top-level term (see d_is_top_level).
+   * This function is called when selector chain n of the form
+   * S_1( ... S_m( x ) ... ) is registered to the theory of datatypes, where
+   * tn is the type of n, d indicates the depth of n (the sum of weights of the
+   * selectors S_1...S_m), and topLevel is whether n is a top-level term
+   * (see d_is_top_level). We refer to n as a "search term".
    *
    * The purpose of this function is to notify this class that symmetry breaking
    * lemmas should be instantiated for n. Any symmetry breaking lemmas that
@@ -261,27 +274,74 @@ private:
   /** add symmetry breaking lemma
    *
    * This adds the lemma R => lem{ x -> n } to lemmas, where R is a "relevancy
-   * condition" that states which contexts n is relevant in (contexts in which
-   * the selector chain n is specified).
+   * condition" that states which contexts n is relevant in (see
+   * getRelevancyCondition).
    */
   void addSymBreakLemma(Node lem, TNode x, TNode n, std::vector<Node>& lemmas);
+  //------------------------end dynamic symmetry breaking
 
- private:
-  std::map< Node, Node > d_rlv_cond;
+  /** Get relevancy condition
+   *
+   * This returns a predicate that holds in the contexts in which the selector
+   * chain n is specified. For example, the relevancy condition for
+   * sel_{C2,1}( sel_{C1,1}( d ) ) is is-C1( d ) ^ is-C2( sel_{C1,1}( d ) ).
+   * If shared selectors are enabled, this is a conjunction of disjunctions,
+   * since shared selectors may apply to multiple constructors.
+   */
   Node getRelevancyCondition( Node n );
-private:
-  std::map< TypeNode, std::map< int, std::map< unsigned, Node > > > d_simple_sb_pred;
-  // user-context dependent if sygus-incremental
-  std::map< Node, unsigned > d_simple_proc;
-  //get simple symmetry breaking predicate
+  /** Cache of the above function */
+  std::map<Node, Node> d_rlv_cond;
+
+  //------------------------static symmetry breaking
+  /** Get simple symmetry breakind predicate
+   *
+   * This function returns the "static" symmetry breaking lemma template for
+   * terms with type tn and constructor index tindex, for the given depth. This
+   * includes inferences about size with depth=0. Given grammar:
+   *   A -> ite( B, A, A ) | A+A | x | 1 | 0
+   *   B -> A = A
+   * Examples of static symmetry breaking lemma templates are:
+   *   for +, depth 0: size(z)=size(z.1)+size(z.2)+1
+   *   for +, depth 1: ~is-0( z.1 ) ^ ~is-0( z.2 ) ^ F
+   *     where F ensures the constructor of z.1 is less than that of z.2 based
+   *     on some ordering.
+   *   for ite, depth 1: z.2 != z.3
+   * These templates can be thought of as "hard-coded" cases of dynamic symmetry
+   * breaking lemma templates. Notice that the above lemma templates are in
+   * terms of getFreeVar( tn ), hence only one is created per
+   * (constructor, depth). A static symmetry break lemma template F[z] for
+   * constructor C are included in lemmas of the form:
+   *   is-C( t ) => F[t]
+   * where t is a search term, see registerSearchTerm for definition of search
+   * term.
+   */
   Node getSimpleSymBreakPred( TypeNode tn, int tindex, unsigned depth );
+  /** Cache of the above function */
+  std::map<TypeNode, std::map<int, std::map<unsigned, Node>>> d_simple_sb_pred;
+  /**
+   * For each search term, this stores the maximum depth for which we have added
+   * a static symmetry breaking lemma.
+   *
+   * This should be user context-dependent if sygus is updated to work in
+   * incremental mode.
+   */
+  std::unordered_map<Node, unsigned, NodeHashFunction> d_simple_proc;
+  //------------------------end static symmetry breaking
+
+  /** Get the canonical free variable for type tn */
   TNode getFreeVar( TypeNode tn );
   Node getTermOrderPredicate( Node n1, Node n2 );
 private:
-  //should be user-context dependent if sygus in incremental mode
-  std::map< Node, bool > d_register_st;
-  void registerSizeTerm( Node e, std::vector< Node >& lemmas );
-  class SearchSizeInfo {
+ /**
+  * Map from registered variables to whether they are a sygus enumerator.
+  *
+  * This should be user context-dependent if sygus is updated to work in
+  * incremental mode.
+  */
+ std::map<Node, bool> d_register_st;
+ void registerSizeTerm(Node e, std::vector<Node>& lemmas);
+ class SearchSizeInfo
+ {
   public:
     SearchSizeInfo( Node t, context::Context* c ) : d_this( t ), d_curr_search_size(0), d_curr_lit( c, 0 ) {}
     Node d_this;
@@ -323,16 +383,6 @@ private:
   int getGuardStatus( Node g );
 private:
   void assertIsConst( Node n, bool polarity, std::vector< Node >& lemmas );
-public:
-  SygusSymBreakNew( TheoryDatatypes * td, quantifiers::TermDbSygus * tds, context::Context* c );
-  ~SygusSymBreakNew();
-  /** add tester */
-  void assertTester( int tindex, TNode n, Node exp, std::vector< Node >& lemmas );
-  void assertFact( Node n, bool polarity, std::vector< Node >& lemmas );
-  void preRegisterTerm( TNode n, std::vector< Node >& lemmas  );
-  void check( std::vector< Node >& lemmas );
-public:
-  Node getNextDecisionRequest( unsigned& priority, std::vector< Node >& lemmas );
 };
 
 }
index e8bdf2083b0ba2b96e712aaae138d24db2e752d3..40183fe9c6cc9d36e41b09e80468845f4ac81038 100644 (file)
@@ -161,42 +161,53 @@ Node TermDbSygus::mkGeneric(const Datatype& dt, int c, std::map<int, Node>& pre)
   return mkGeneric(dt, c, var_count, pre);
 }
 
+struct SygusToBuiltinAttributeId
+{
+};
+typedef expr::Attribute<SygusToBuiltinAttributeId, Node>
+    SygusToBuiltinAttribute;
+
 Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) {
   Assert( n.getType()==tn );
   Assert( tn.isDatatype() );
-  std::map< Node, Node >::iterator it = d_sygus_to_builtin[tn].find( n );
-  if( it==d_sygus_to_builtin[tn].end() ){
-    Trace("sygus-db-debug") << "SygusToBuiltin : compute for " << n << ", type = " << tn << std::endl;
-    const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype();
-    if( n.getKind()==APPLY_CONSTRUCTOR ){
-      unsigned i = Datatype::indexOf( n.getOperator().toExpr() );
-      Assert( n.getNumChildren()==dt[i].getNumArgs() );
-      std::map< TypeNode, int > var_count;
-      std::map< int, Node > pre;
-      for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
-      {
-        pre[j] = sygusToBuiltin( n[j], getArgType( dt[i], j ) );
-      }
-      Node ret = mkGeneric(dt, i, var_count, pre);
-      Trace("sygus-db-debug") << "SygusToBuiltin : Generic is " << ret << std::endl;
-      d_sygus_to_builtin[tn][n] = ret;
-      return ret;
-    }
-    if (n.hasAttribute(SygusPrintProxyAttribute()))
+
+  // has it already been computed?
+  if (n.hasAttribute(SygusToBuiltinAttribute()))
+  {
+    return n.getAttribute(SygusToBuiltinAttribute());
+  }
+
+  Trace("sygus-db-debug") << "SygusToBuiltin : compute for " << n
+                          << ", type = " << tn << std::endl;
+  const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype();
+  if (n.getKind() == APPLY_CONSTRUCTOR)
+  {
+    unsigned i = Datatype::indexOf(n.getOperator().toExpr());
+    Assert(n.getNumChildren() == dt[i].getNumArgs());
+    std::map<TypeNode, int> var_count;
+    std::map<int, Node> pre;
+    for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
     {
-      // this variable was associated by an attribute to a builtin node
-      return n.getAttribute(SygusPrintProxyAttribute());
+      pre[j] = sygusToBuiltin(n[j], getArgType(dt[i], j));
     }
-    Assert(isFreeVar(n));
-    // map to builtin variable type
-    int fv_num = getVarNum(n);
-    Assert(!dt.getSygusType().isNull());
-    TypeNode vtn = TypeNode::fromType(dt.getSygusType());
-    Node ret = getFreeVar(vtn, fv_num);
+    Node ret = mkGeneric(dt, i, var_count, pre);
+    Trace("sygus-db-debug")
+        << "SygusToBuiltin : Generic is " << ret << std::endl;
+    n.setAttribute(SygusToBuiltinAttribute(), ret);
     return ret;
-  }else{
-    return it->second;
   }
+  if (n.hasAttribute(SygusPrintProxyAttribute()))
+  {
+    // this variable was associated by an attribute to a builtin node
+    return n.getAttribute(SygusPrintProxyAttribute());
+  }
+  Assert(isFreeVar(n));
+  // map to builtin variable type
+  int fv_num = getVarNum(n);
+  Assert(!dt.getSygusType().isNull());
+  TypeNode vtn = TypeNode::fromType(dt.getSygusType());
+  Node ret = getFreeVar(vtn, fv_num);
+  return ret;
 }
 
 Node TermDbSygus::sygusSubstituted( TypeNode tn, Node n, std::vector< Node >& args ) {
index 7ef9e6151b21087ca89fdf351fe9ec3c0768ac9c..57a127d8d946b94296395d34c6c5ce0f58af9a56 100644 (file)
@@ -187,8 +187,6 @@ class TermDbSygus {
   //------------------------------end enumerators
 
   //-----------------------------conversion from sygus to builtin
-  /** cache for sygusToBuiltin */
-  std::map<TypeNode, std::map<Node, Node> > d_sygus_to_builtin;
   /** a cache of fresh variables for each type
    *
    * We store two versions of this list: