Fix symmetry breaking for multiple sygus types (#3775)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 19 Feb 2020 23:14:01 +0000 (17:14 -0600)
committerGitHub <noreply@github.com>
Wed, 19 Feb 2020 23:14:01 +0000 (17:14 -0600)
src/theory/datatypes/sygus_extension.cpp
src/theory/quantifiers/sygus/example_eval_cache.cpp
src/theory/quantifiers/sygus/example_eval_cache.h
src/theory/quantifiers/sygus/sygus_enumerator.cpp

index 1e593a664588dedba6007de8118a29476c0d14b5..95b73b2fec5f799ee11d508e3a10cf0230e6590f 100644 (file)
@@ -952,9 +952,13 @@ void SygusExtension::registerSearchTerm( TypeNode tn, unsigned d, Node n, bool t
   Assert(ita != d_term_to_anchor.end());
   Node a = ita->second;
   Assert(!a.isNull());
-  if( std::find( d_cache[a].d_search_terms[tn][d].begin(), d_cache[a].d_search_terms[tn][d].end(), n )==d_cache[a].d_search_terms[tn][d].end() ){
+  SearchCache& sca = d_cache[a];
+  if (std::find(
+          sca.d_search_terms[tn][d].begin(), sca.d_search_terms[tn][d].end(), n)
+      == sca.d_search_terms[tn][d].end())
+  {
     Trace("sygus-sb-debug") << "  register search term : " << n << " at depth " << d << ", type=" << tn << ", tl=" << topLevel << std::endl;
-    d_cache[a].d_search_terms[tn][d].push_back( n );
+    sca.d_search_terms[tn][d].push_back(n);
     if( !options::sygusSymBreakLazy() ){
       addSymBreakLemmasFor( tn, n, d, lemmas );
     }
@@ -1024,11 +1028,11 @@ Node SygusExtension::registerSearchValue(Node a,
   std::map<TypeNode, int> var_count;
   Node cnv = d_tds->canonizeBuiltin(nv, var_count);
   Trace("sygus-sb-debug") << "  ...canonized value is " << cnv << std::endl;
+  SearchCache& sca = d_cache[a];
   // must do this for all nodes, regardless of top-level
-  if (d_cache[a].d_search_val_proc.find(cnv)
-      == d_cache[a].d_search_val_proc.end())
+  if (sca.d_search_val_proc.find(cnv) == sca.d_search_val_proc.end())
   {
-    d_cache[a].d_search_val_proc.insert(cnv);
+    sca.d_search_val_proc.insert(cnv);
     // get the root (for PBE symmetry breaking)
     Assert(d_anchor_to_conj.find(a) != d_anchor_to_conj.end());
     quantifiers::SynthConjecture* aconj = d_anchor_to_conj[a];
@@ -1049,11 +1053,16 @@ Node SygusExtension::registerSearchValue(Node a,
           a, nv, dbzet, Node::null(), var_count, lemmas);
       return Node::null();
     }else{
+      std::unordered_map<Node, Node, NodeHashFunction>& scasv =
+          sca.d_search_val[tn];
+      std::unordered_map<Node, unsigned, NodeHashFunction>& scasvs =
+          sca.d_search_val_sz[tn];
       std::unordered_map<Node, Node, NodeHashFunction>::iterator itsv =
-          d_cache[a].d_search_val[tn].find(bvr);
+          scasv.find(bvr);
       Node bad_val_bvr;
       bool by_examples = false;
-      if( itsv==d_cache[a].d_search_val[tn].end() ){
+      if (itsv == scasv.end())
+      {
         // Is it equivalent under examples?
         Node bvr_equiv;
         if (options::sygusSymBreakPbe())
@@ -1063,17 +1072,17 @@ Node SygusExtension::registerSearchValue(Node a,
           quantifiers::ExampleEvalCache* eec = aconj->getExampleEvalCache(a);
           if (eec != nullptr)
           {
-            bvr_equiv = eec->addSearchVal(bvr);
+            bvr_equiv = eec->addSearchVal(tn, bvr);
           }
         }
         if( !bvr_equiv.isNull() ){
           if( bvr_equiv!=bvr ){
             Trace("sygus-sb-debug") << "......adding search val for " << bvr << " returned " << bvr_equiv << std::endl;
-            Assert(d_cache[a].d_search_val[tn].find(bvr_equiv)
-                   != d_cache[a].d_search_val[tn].end());
-            Trace("sygus-sb-debug") << "......search value was " << d_cache[a].d_search_val[tn][bvr_equiv] << std::endl;
+            Assert(scasv.find(bvr_equiv) != scasv.end());
+            Trace("sygus-sb-debug")
+                << "......search value was " << scasv[bvr_equiv] << std::endl;
             if( Trace.isOn("sygus-sb-exc") ){
-              Node prev = d_tds->sygusToBuiltin( d_cache[a].d_search_val[tn][bvr_equiv], tn );
+              Node prev = d_tds->sygusToBuiltin(scasv[bvr_equiv], tn);
               Trace("sygus-sb-exc") << "  ......programs " << prev << " and " << bv << " are equivalent up to examples." << std::endl;
             }
             bad_val_bvr = bvr_equiv;
@@ -1081,14 +1090,16 @@ Node SygusExtension::registerSearchValue(Node a,
           }
         }
         //store rewritten values, regardless of whether it will be considered
-        d_cache[a].d_search_val[tn][bvr] = nv;
-        d_cache[a].d_search_val_sz[tn][bvr] = sz;
-      }else{
+        scasv[bvr] = nv;
+        scasvs[bvr] = sz;
+      }
+      else
+      {
         bad_val_bvr = bvr;
         if( Trace.isOn("sygus-sb-exc") ){
           Node prev_bv = d_tds->sygusToBuiltin( itsv->second, tn );
           Trace("sygus-sb-exc") << "  ......programs " << prev_bv << " and " << bv << " rewrite to " << bvr << "." << std::endl;
-        } 
+        }
       }
 
       if (options::sygusRewVerify())
@@ -1111,16 +1122,15 @@ Node SygusExtension::registerSearchValue(Node a,
 
       if( !bad_val_bvr.isNull() ){
         Node bad_val = nv;
-        Node bad_val_o = d_cache[a].d_search_val[tn][bad_val_bvr];
-        Assert(d_cache[a].d_search_val_sz[tn].find(bad_val_bvr)
-               != d_cache[a].d_search_val_sz[tn].end());
-        unsigned prev_sz = d_cache[a].d_search_val_sz[tn][bad_val_bvr];
+        Node bad_val_o = scasv[bad_val_bvr];
+        Assert(scasvs.find(bad_val_bvr) != scasvs.end());
+        unsigned prev_sz = scasvs[bad_val_bvr];
         bool doFlip = (prev_sz > sz);
         if (doFlip)
         {
           //swap : the excluded value is the previous
-          d_cache[a].d_search_val_sz[tn][bad_val_bvr] = sz;
-          bad_val = d_cache[a].d_search_val[tn][bad_val_bvr];
+          scasvs[bad_val_bvr] = sz;
+          bad_val = scasv[bad_val_bvr];
           bad_val_o = nv;
           if (Trace.isOn("sygus-sb-exc"))
           {
@@ -1196,14 +1206,17 @@ void SygusExtension::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz,
   Trace("sygus-sb-debug") << "     type : " << tn << std::endl;
   Trace("sygus-sb-debug") << "     size : " << sz << std::endl;
   Assert(!a.isNull());
-  d_cache[a].d_sb_lemmas[tn][sz].push_back( lem );
+  SearchCache& sca = d_cache[a];
+  sca.d_sb_lemmas[tn][sz].push_back(lem);
   TNode x = getFreeVar( tn );
   unsigned csz = getSearchSizeForAnchor( a );
   int max_depth = ((int)csz)-((int)sz);
   NodeManager* nm = NodeManager::currentNM();
   for( int d=0; d<=max_depth; d++ ){
-    std::map< unsigned, std::vector< Node > >::iterator itt = d_cache[a].d_search_terms[tn].find( d );
-    if( itt!=d_cache[a].d_search_terms[tn].end() ){
+    std::map<unsigned, std::vector<Node>>::iterator itt =
+        sca.d_search_terms[tn].find(d);
+    if (itt != sca.d_search_terms[tn].end())
+    {
       for (const TNode& t : itt->second)
       {
         if (!options::sygusSymBreakLazy()
@@ -1233,10 +1246,13 @@ void SygusExtension::addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, Node
   Assert(!a.isNull());
   Trace("sygus-sb-debug2") << "add sym break lemmas for " << t << " " << d
                            << " " << a << std::endl;
-  std::map< TypeNode, std::map< unsigned, std::vector< Node > > >::iterator its = d_cache[a].d_sb_lemmas.find( tn );
+  SearchCache& sca = d_cache[a];
+  std::map<TypeNode, std::map<unsigned, std::vector<Node>>>::iterator its =
+      sca.d_sb_lemmas.find(tn);
   Node rlv = getRelevancyCondition(t);
   NodeManager* nm = NodeManager::currentNM();
-  if( its != d_cache[a].d_sb_lemmas.end() ){
+  if (its != sca.d_sb_lemmas.end())
+  {
     TNode x = getFreeVar( tn );
     //get symmetry breaking lemmas for this term 
     unsigned csz = getSearchSizeForAnchor( a );
index 2a43335d72c53b0f9117d88f9715f9c7ad8afbf2..ee539497206dbcb52d5a01a61f0eafc5c244ddc4 100644 (file)
@@ -43,7 +43,7 @@ ExampleEvalCache::ExampleEvalCache(TermDbSygus* tds,
 
 ExampleEvalCache::~ExampleEvalCache() {}
 
-Node ExampleEvalCache::addSearchVal(Node bv)
+Node ExampleEvalCache::addSearchVal(TypeNode tn, Node bv)
 {
   if (!d_indexSearchVals)
   {
@@ -53,7 +53,7 @@ Node ExampleEvalCache::addSearchVal(Node bv)
   std::vector<Node> vals;
   evaluateVec(bv, vals, true);
   Trace("sygus-pbe-debug") << "Add to trie..." << std::endl;
-  Node ret = d_trie.addOrGetTerm(bv, vals);
+  Node ret = d_trie[tn].addOrGetTerm(bv, vals);
   Trace("sygus-pbe-debug") << "...got " << ret << std::endl;
   // Only save the cache data if necessary: if the enumerated term
   // is redundant, its cached data will not be used later and thus should
index 6aa78851f93ef25edefc71f33cad1a27b514c446..6c5a65e7fcdeabf8ae8621ba1825e33f4cc223fe 100644 (file)
@@ -81,6 +81,10 @@ class ExampleEvalCache
    * procedure for SyGuS datatypes or the SyGuS fast enumerator when we are
    * considering a value of enumerator e passed to the constructor of this
    * class whose analog in the signature of builtin theory is bvr.
+   * 
+   * The type tn passed to this function is the sygus type of the term whose
+   * builtin equivalent is bvr. Terms with distinct types must be cached
+   * independently since two sygus types may not generate the same terms.
    *
    * This returns either:
    * - A term that is equivalent to bvr up to examples that was passed as the
@@ -94,7 +98,7 @@ class ExampleEvalCache
    * result of the evaluation of bvr is cached by this class, and can be
    * later accessed by evaluateVec below.
    */
-  Node addSearchVal(Node bvr);
+  Node addSearchVal(TypeNode tn, Node bvr);
   //----------------------------------- evaluating terms
   /** Evaluate vector
    *
@@ -149,7 +153,7 @@ class ExampleEvalCache
    * This is used for symmetry breaking in quantifier-free reasoning
    * about SyGuS datatypes.
    */
-  NodeTrie d_trie;
+  std::map< TypeNode, NodeTrie> d_trie;
   /** cache for evaluate */
   std::map<Node, std::vector<Node>> d_exOutCache;
 };
index 551a9d225a86275895a2a0485098a2ca51ab353a..3923361b1dbae4ac3ef61cd5c07a9e2908d1c096 100644 (file)
@@ -346,7 +346,7 @@ bool SygusEnumerator::TermCache::addTerm(Node n)
     {
       ++(d_stats->d_enumTermsExampleEval);
       // Is it equivalent under examples?
-      Node bne = d_eec->addSearchVal(bnr);
+      Node bne = d_eec->addSearchVal(d_tn, bnr);
       if (!bne.isNull())
       {
         if (bnr != bne)